diff --git a/.rat-excludes b/.rat-excludes
index 236c2db05367c..9165872b9fb27 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -93,3 +93,5 @@ INDEX
.lintr
gen-java.*
.*avpr
+org.apache.spark.sql.sources.DataSourceRegister
+.*parquet
diff --git a/R/install-dev.bat b/R/install-dev.bat
index f32670b67de96..008a5c668bc45 100644
--- a/R/install-dev.bat
+++ b/R/install-dev.bat
@@ -25,8 +25,3 @@ set SPARK_HOME=%~dp0..
MKDIR %SPARK_HOME%\R\lib
R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\
-
-rem Zip the SparkR package so that it can be distributed to worker nodes on YARN
-pushd %SPARK_HOME%\R\lib
-%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR
-popd
diff --git a/R/install-dev.sh b/R/install-dev.sh
index 4972bb9217072..59d98c9c7a646 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -42,8 +42,4 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo
# Install SparkR to $LIB_DIR
R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/
-# Zip the SparkR package so that it can be distributed to worker nodes on YARN
-cd $LIB_DIR
-jar cfM "$LIB_DIR/sparkr.zip" SparkR
-
popd > /dev/null
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 4949d86d20c91..83e64897216b1 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -29,6 +29,7 @@ Collate:
'client.R'
'context.R'
'deserialize.R'
+ 'functions.R'
'mllib.R'
'serialize.R'
'sparkR.R'
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 7f7a8a2e4de24..b2d92bdf4840e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -12,7 +12,8 @@ export("print.jobj")
# MLlib integration
exportMethods("glm",
- "predict")
+ "predict",
+ "summary")
# Job group lifecycle management methods
export("setJobGroup",
@@ -28,6 +29,7 @@ exportMethods("arrange",
"count",
"crosstab",
"describe",
+ "dim",
"distinct",
"dropna",
"dtypes",
@@ -44,11 +46,16 @@ exportMethods("arrange",
"isLocal",
"join",
"limit",
+ "merge",
+ "names",
+ "ncol",
+ "nrow",
"orderBy",
"mutate",
"names",
"persist",
"printSchema",
+ "rbind",
"registerTempTable",
"rename",
"repartition",
@@ -63,8 +70,10 @@ exportMethods("arrange",
"show",
"showDF",
"summarize",
+ "summary",
"take",
"unionAll",
+ "unique",
"unpersist",
"where",
"withColumn",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index f4c93d3c7dd67..895603235011e 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -255,6 +255,16 @@ setMethod("names",
columns(x)
})
+#' @rdname columns
+setMethod("names<-",
+ signature(x = "DataFrame"),
+ function(x, value) {
+ if (!is.null(value)) {
+ sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value)))
+ dataFrame(sdf)
+ }
+ })
+
#' Register Temporary Table
#'
#' Registers a DataFrame as a Temporary Table in the SQLContext
@@ -473,6 +483,18 @@ setMethod("distinct",
dataFrame(sdf)
})
+#' @title Distinct rows in a DataFrame
+#'
+#' @description Returns a new DataFrame containing distinct rows in this DataFrame
+#'
+#' @rdname unique
+#' @aliases unique
+setMethod("unique",
+ signature(x = "DataFrame"),
+ function(x) {
+ distinct(x)
+ })
+
#' Sample
#'
#' Return a sampled subset of this DataFrame using a random seed.
@@ -534,6 +556,58 @@ setMethod("count",
callJMethod(x@sdf, "count")
})
+#' @title Number of rows for a DataFrame
+#' @description Returns number of rows in a DataFrame
+#'
+#' @name nrow
+#'
+#' @rdname nrow
+#' @aliases count
+setMethod("nrow",
+ signature(x = "DataFrame"),
+ function(x) {
+ count(x)
+ })
+
+#' Returns the number of columns in a DataFrame
+#'
+#' @param x a SparkSQL DataFrame
+#'
+#' @rdname ncol
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlContext, path)
+#' ncol(df)
+#' }
+setMethod("ncol",
+ signature(x = "DataFrame"),
+ function(x) {
+ length(columns(x))
+ })
+
+#' Returns the dimensions (number of rows and columns) of a DataFrame
+#' @param x a SparkSQL DataFrame
+#'
+#' @rdname dim
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlContext, path)
+#' dim(df)
+#' }
+setMethod("dim",
+ signature(x = "DataFrame"),
+ function(x) {
+ c(count(x), ncol(x))
+ })
+
#' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame.
#'
#' @param x A SparkSQL DataFrame
@@ -1205,6 +1279,15 @@ setMethod("join",
dataFrame(sdf)
})
+#' @rdname merge
+#' @aliases join
+setMethod("merge",
+ signature(x = "DataFrame", y = "DataFrame"),
+ function(x, y, joinExpr = NULL, joinType = NULL, ...) {
+ join(x, y, joinExpr, joinType)
+ })
+
+
#' UnionAll
#'
#' Return a new DataFrame containing the union of rows in this DataFrame
@@ -1231,6 +1314,22 @@ setMethod("unionAll",
dataFrame(unioned)
})
+#' @title Union two or more DataFrames
+#'
+#' @description Returns a new DataFrame containing rows of all parameters.
+#'
+#' @rdname rbind
+#' @aliases unionAll
+setMethod("rbind",
+ signature(... = "DataFrame"),
+ function(x, ..., deparse.level = 1) {
+ if (nargs() == 3) {
+ unionAll(x, ...)
+ } else {
+ unionAll(x, Recall(..., deparse.level = 1))
+ }
+ })
+
#' Intersect
#'
#' Return a new DataFrame containing rows only in both this DataFrame
@@ -1322,9 +1421,11 @@ setMethod("write.df",
"org.apache.spark.sql.parquet")
}
allModes <- c("append", "overwrite", "error", "ignore")
+ # nolint start
if (!(mode %in% allModes)) {
stop('mode should be one of "append", "overwrite", "error", "ignore"')
}
+ # nolint end
jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode)
options <- varargsToEnv(...)
if (!is.null(path)) {
@@ -1384,9 +1485,11 @@ setMethod("saveAsTable",
"org.apache.spark.sql.parquet")
}
allModes <- c("append", "overwrite", "error", "ignore")
+ # nolint start
if (!(mode %in% allModes)) {
stop('mode should be one of "append", "overwrite", "error", "ignore"')
}
+ # nolint end
jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode)
options <- varargsToEnv(...)
callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options)
@@ -1430,6 +1533,19 @@ setMethod("describe",
dataFrame(sdf)
})
+#' @title Summary
+#'
+#' @description Computes statistics for numeric columns of the DataFrame
+#'
+#' @rdname summary
+#' @aliases describe
+setMethod("summary",
+ signature(x = "DataFrame"),
+ function(x) {
+ describe(x)
+ })
+
+
#' dropna
#'
#' Returns a new DataFrame omitting rows with null values.
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index d2d096709245d..051e441d4e063 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -85,7 +85,9 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val)
isPipelinable <- function(rdd) {
e <- rdd@env
+ # nolint start
!(e$isCached || e$isCheckpointed)
+ # nolint end
}
if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) {
@@ -97,7 +99,8 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val)
# prev_serializedMode is used during the delayed computation of JRDD in getJRDD
} else {
pipelinedFunc <- function(partIndex, part) {
- func(partIndex, prev@func(partIndex, part))
+ f <- prev@func
+ func(partIndex, f(partIndex, part))
}
.Object@func <- cleanClosure(pipelinedFunc)
.Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline
@@ -841,7 +844,7 @@ setMethod("sampleRDD",
if (withReplacement) {
count <- rpois(1, fraction)
if (count > 0) {
- res[(len + 1):(len + count)] <- rep(list(elem), count)
+ res[ (len + 1) : (len + count) ] <- rep(list(elem), count)
len <- len + count
}
} else {
@@ -1261,12 +1264,12 @@ setMethod("pipeRDD",
signature(x = "RDD", command = "character"),
function(x, command, env = list()) {
func <- function(part) {
- trim.trailing.func <- function(x) {
+ trim_trailing_func <- function(x) {
sub("[\r\n]*$", "", toString(x))
}
- input <- unlist(lapply(part, trim.trailing.func))
+ input <- unlist(lapply(part, trim_trailing_func))
res <- system2(command, stdout = TRUE, input = input, env = env)
- lapply(res, trim.trailing.func)
+ lapply(res, trim_trailing_func)
}
lapplyPartition(x, func)
})
diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R
index 2fb6fae55f28c..49162838b8d1a 100644
--- a/R/pkg/R/backend.R
+++ b/R/pkg/R/backend.R
@@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) {
# TODO: check the status code to output error information
returnStatus <- readInt(conn)
- stopifnot(returnStatus == 0)
+ if (returnStatus != 0) {
+ stop(readString(conn))
+ }
readObject(conn)
}
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 2892e1416cc65..328f595d0805f 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -60,12 +60,6 @@ operators <- list(
)
column_functions1 <- c("asc", "desc", "isNull", "isNotNull")
column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains")
-functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt",
- "first", "last", "lower", "upper", "sumDistinct",
- "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp",
- "expm1", "floor", "log", "log10", "log1p", "rint", "sign",
- "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians")
-binary_mathfunctions<- c("atan2", "hypot")
createOperator <- function(op) {
setMethod(op,
@@ -111,33 +105,6 @@ createColumnFunction2 <- function(name) {
})
}
-createStaticFunction <- function(name) {
- setMethod(name,
- signature(x = "Column"),
- function(x) {
- if (name == "ceiling") {
- name <- "ceil"
- }
- if (name == "sign") {
- name <- "signum"
- }
- jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc)
- column(jc)
- })
-}
-
-createBinaryMathfunctions <- function(name) {
- setMethod(name,
- signature(y = "Column"),
- function(y, x) {
- if (class(x) == "Column") {
- x <- x@jc
- }
- jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x)
- column(jc)
- })
-}
-
createMethods <- function() {
for (op in names(operators)) {
createOperator(op)
@@ -148,12 +115,6 @@ createMethods <- function() {
for (name in column_functions2) {
createColumnFunction2(name)
}
- for (x in functions) {
- createStaticFunction(x)
- }
- for (name in binary_mathfunctions) {
- createBinaryMathfunctions(name)
- }
}
createMethods()
@@ -242,45 +203,3 @@ setMethod("%in%",
jc <- callJMethod(x@jc, "in", table)
return(column(jc))
})
-
-#' Approx Count Distinct
-#'
-#' @rdname column
-#' @return the approximate number of distinct items in a group.
-setMethod("approxCountDistinct",
- signature(x = "Column"),
- function(x, rsd = 0.95) {
- jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd)
- column(jc)
- })
-
-#' Count Distinct
-#'
-#' @rdname column
-#' @return the number of distinct items in a group.
-setMethod("countDistinct",
- signature(x = "Column"),
- function(x, ...) {
- jcol <- lapply(list(...), function (x) {
- x@jc
- })
- jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
- listToSeq(jcol))
- column(jc)
- })
-
-#' @rdname column
-#' @aliases countDistinct
-setMethod("n_distinct",
- signature(x = "Column"),
- function(x, ...) {
- countDistinct(x, ...)
- })
-
-#' @rdname column
-#' @aliases count
-setMethod("n",
- signature(x = "Column"),
- function(x) {
- count(x)
- })
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 43be9c904fdf6..720990e1c6087 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -121,7 +121,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
numSlices <- length(coll)
sliceLen <- ceiling(length(coll) / numSlices)
- slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)])
+ slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)])
# Serialize each slice: obtain a list of raws, or a list of lists (slices) of
# 2-tuples of raws
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
new file mode 100644
index 0000000000000..a15d2d5da534e
--- /dev/null
+++ b/R/pkg/R/functions.R
@@ -0,0 +1,123 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#' @include generics.R column.R
+NULL
+
+#' @title S4 expression functions for DataFrame column(s)
+#' @description These are expression functions on DataFrame columns
+
+functions1 <- c(
+ "abs", "acos", "approxCountDistinct", "ascii", "asin", "atan",
+ "avg", "base64", "bin", "bitwiseNOT", "cbrt", "ceil", "cos", "cosh", "count",
+ "crc32", "dayofmonth", "dayofyear", "exp", "explode", "expm1", "factorial",
+ "first", "floor", "hex", "hour", "initcap", "isNaN", "last", "last_day",
+ "length", "log", "log10", "log1p", "log2", "lower", "ltrim", "max", "md5",
+ "mean", "min", "minute", "month", "negate", "quarter", "reverse",
+ "rint", "round", "rtrim", "second", "sha1", "signum", "sin", "sinh", "size",
+ "soundex", "sqrt", "sum", "sumDistinct", "tan", "tanh", "toDegrees",
+ "toRadians", "to_date", "trim", "unbase64", "unhex", "upper", "weekofyear",
+ "year")
+functions2 <- c(
+ "atan2", "datediff", "hypot", "levenshtein", "months_between", "nanvl", "pmod")
+
+createFunction1 <- function(name) {
+ setMethod(name,
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc)
+ column(jc)
+ })
+}
+
+createFunction2 <- function(name) {
+ setMethod(name,
+ signature(y = "Column"),
+ function(y, x) {
+ if (class(x) == "Column") {
+ x <- x@jc
+ }
+ jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x)
+ column(jc)
+ })
+}
+
+createFunctions <- function() {
+ for (name in functions1) {
+ createFunction1(name)
+ }
+ for (name in functions2) {
+ createFunction2(name)
+ }
+}
+
+createFunctions()
+
+#' Approx Count Distinct
+#'
+#' @rdname functions
+#' @return the approximate number of distinct items in a group.
+setMethod("approxCountDistinct",
+ signature(x = "Column"),
+ function(x, rsd = 0.95) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd)
+ column(jc)
+ })
+
+#' Count Distinct
+#'
+#' @rdname functions
+#' @return the number of distinct items in a group.
+setMethod("countDistinct",
+ signature(x = "Column"),
+ function(x, ...) {
+ jcol <- lapply(list(...), function (x) {
+ x@jc
+ })
+ jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
+ listToSeq(jcol))
+ column(jc)
+ })
+
+#' @rdname functions
+#' @aliases ceil
+setMethod("ceiling",
+ signature(x = "Column"),
+ function(x) {
+ ceil(x)
+ })
+
+#' @rdname functions
+#' @aliases signum
+setMethod("sign", signature(x = "Column"),
+ function(x) {
+ signum(x)
+ })
+
+#' @rdname functions
+#' @aliases countDistinct
+setMethod("n_distinct", signature(x = "Column"),
+ function(x, ...) {
+ countDistinct(x, ...)
+ })
+
+#' @rdname functions
+#' @aliases count
+setMethod("n", signature(x = "Column"),
+ function(x) {
+ count(x)
+ })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index a3a121058e165..f11e7fcb6a07c 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -461,6 +461,10 @@ setGeneric("isLocal", function(x) { standardGeneric("isLocal") })
#' @export
setGeneric("limit", function(x, num) {standardGeneric("limit") })
+#' @rdname merge
+#' @export
+setGeneric("merge")
+
#' @rdname withColumn
#' @export
setGeneric("mutate", function(x, ...) {standardGeneric("mutate") })
@@ -531,6 +535,10 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
#' @export
setGeneric("summarize", function(x,...) { standardGeneric("summarize") })
+#' @rdname summary
+#' @export
+setGeneric("summary", function(x, ...) { standardGeneric("summary") })
+
# @rdname tojson
# @export
setGeneric("toJSON", function(x) { standardGeneric("toJSON") })
@@ -567,10 +575,6 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun
#' @export
setGeneric("asc", function(x) { standardGeneric("asc") })
-#' @rdname column
-#' @export
-setGeneric("avg", function(x, ...) { standardGeneric("avg") })
-
#' @rdname column
#' @export
setGeneric("between", function(x, bounds) { standardGeneric("between") })
@@ -579,13 +583,10 @@ setGeneric("between", function(x, bounds) { standardGeneric("between") })
#' @export
setGeneric("cast", function(x, dataType) { standardGeneric("cast") })
-#' @rdname column
-#' @export
-setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
-
#' @rdname column
#' @export
setGeneric("contains", function(x, ...) { standardGeneric("contains") })
+
#' @rdname column
#' @export
setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") })
@@ -650,22 +651,194 @@ setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
#' @export
setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") })
-#' @rdname column
+
+###################### Expression Function Methods ##########################
+
+#' @rdname functions
+#' @export
+setGeneric("ascii", function(x) { standardGeneric("ascii") })
+
+#' @rdname functions
+#' @export
+setGeneric("avg", function(x, ...) { standardGeneric("avg") })
+
+#' @rdname functions
+#' @export
+setGeneric("base64", function(x) { standardGeneric("base64") })
+
+#' @rdname functions
+#' @export
+setGeneric("bin", function(x) { standardGeneric("bin") })
+
+#' @rdname functions
+#' @export
+setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") })
+
+#' @rdname functions
+#' @export
+setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
+
+#' @rdname functions
+#' @export
+setGeneric("ceil", function(x) { standardGeneric("ceil") })
+
+#' @rdname functions
+#' @export
+setGeneric("crc32", function(x) { standardGeneric("crc32") })
+
+#' @rdname functions
+#' @export
+setGeneric("datediff", function(y, x) { standardGeneric("datediff") })
+
+#' @rdname functions
+#' @export
+setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") })
+
+#' @rdname functions
+#' @export
+setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") })
+
+#' @rdname functions
+#' @export
+setGeneric("explode", function(x) { standardGeneric("explode") })
+
+#' @rdname functions
+#' @export
+setGeneric("hex", function(x) { standardGeneric("hex") })
+
+#' @rdname functions
+#' @export
+setGeneric("hour", function(x) { standardGeneric("hour") })
+
+#' @rdname functions
+#' @export
+setGeneric("initcap", function(x) { standardGeneric("initcap") })
+
+#' @rdname functions
+#' @export
+setGeneric("isNaN", function(x) { standardGeneric("isNaN") })
+
+#' @rdname functions
+#' @export
+setGeneric("last_day", function(x) { standardGeneric("last_day") })
+
+#' @rdname functions
+#' @export
+setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") })
+
+#' @rdname functions
+#' @export
+setGeneric("lower", function(x) { standardGeneric("lower") })
+
+#' @rdname functions
+#' @export
+setGeneric("ltrim", function(x) { standardGeneric("ltrim") })
+
+#' @rdname functions
+#' @export
+setGeneric("md5", function(x) { standardGeneric("md5") })
+
+#' @rdname functions
+#' @export
+setGeneric("minute", function(x) { standardGeneric("minute") })
+
+#' @rdname functions
+#' @export
+setGeneric("month", function(x) { standardGeneric("month") })
+
+#' @rdname functions
+#' @export
+setGeneric("months_between", function(y, x) { standardGeneric("months_between") })
+
+#' @rdname functions
+#' @export
+setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") })
+
+#' @rdname functions
+#' @export
+setGeneric("negate", function(x) { standardGeneric("negate") })
+
+#' @rdname functions
+#' @export
+setGeneric("pmod", function(y, x) { standardGeneric("pmod") })
+
+#' @rdname functions
+#' @export
+setGeneric("quarter", function(x) { standardGeneric("quarter") })
+
+#' @rdname functions
+#' @export
+setGeneric("reverse", function(x) { standardGeneric("reverse") })
+
+#' @rdname functions
+#' @export
+setGeneric("rtrim", function(x) { standardGeneric("rtrim") })
+
+#' @rdname functions
+#' @export
+setGeneric("second", function(x) { standardGeneric("second") })
+
+#' @rdname functions
+#' @export
+setGeneric("sha1", function(x) { standardGeneric("sha1") })
+
+#' @rdname functions
+#' @export
+setGeneric("signum", function(x) { standardGeneric("signum") })
+
+#' @rdname functions
+#' @export
+setGeneric("size", function(x) { standardGeneric("size") })
+
+#' @rdname functions
+#' @export
+setGeneric("soundex", function(x) { standardGeneric("soundex") })
+
+#' @rdname functions
#' @export
setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") })
-#' @rdname column
+#' @rdname functions
#' @export
setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") })
-#' @rdname column
+#' @rdname functions
#' @export
setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
-#' @rdname column
+#' @rdname functions
+#' @export
+setGeneric("to_date", function(x) { standardGeneric("to_date") })
+
+#' @rdname functions
+#' @export
+setGeneric("trim", function(x) { standardGeneric("trim") })
+
+#' @rdname functions
+#' @export
+setGeneric("unbase64", function(x) { standardGeneric("unbase64") })
+
+#' @rdname functions
+#' @export
+setGeneric("unhex", function(x) { standardGeneric("unhex") })
+
+#' @rdname functions
#' @export
setGeneric("upper", function(x) { standardGeneric("upper") })
+#' @rdname functions
+#' @export
+setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") })
+
+#' @rdname functions
+#' @export
+setGeneric("year", function(x) { standardGeneric("year") })
+
+
#' @rdname glm
#' @export
setGeneric("glm")
+
+#' @rdname rbind
+#' @export
+setGeneric("rbind", signature = "...")
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 6a8bacaa552c6..cea3d760d05fe 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -56,10 +56,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
#'
#' Makes predictions from a model produced by glm(), similarly to R's predict().
#'
-#' @param model A fitted MLlib model
+#' @param object A fitted MLlib model
#' @param newData DataFrame for testing
#' @return DataFrame containing predicted values
-#' @rdname glm
+#' @rdname predict
#' @export
#' @examples
#'\dontrun{
@@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})
+
+#' Get the summary of a model
+#'
+#' Returns the summary of a model produced by glm(), similarly to R's summary().
+#'
+#' @param x A fitted MLlib model
+#' @return a list with a 'coefficients' component, which is the matrix of coefficients. See
+#' summary.glm for more information.
+#' @rdname summary
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' summary(model)
+#'}
+setMethod("summary", signature(x = "PipelineModel"),
+ function(x, ...) {
+ features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelFeatures", x@model)
+ weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelWeights", x@model)
+ coefficients <- as.matrix(unlist(weights))
+ colnames(coefficients) <- c("Estimate")
+ rownames(coefficients) <- unlist(features)
+ return(list(coefficients = coefficients))
+ })
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
index 83801d3209700..199c3fd6ab1b2 100644
--- a/R/pkg/R/pairRDD.R
+++ b/R/pkg/R/pairRDD.R
@@ -879,7 +879,7 @@ setMethod("sampleByKey",
if (withReplacement) {
count <- rpois(1, frac)
if (count > 0) {
- res[(len + 1):(len + count)] <- rep(list(elem), count)
+ res[ (len + 1) : (len + count) ] <- rep(list(elem), count)
len <- len + count
}
} else {
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 3f45589a50443..4f9f4d9cad2a8 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -32,7 +32,7 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL,
}
results <- if (arrSize > 0) {
- lapply(0:(arrSize - 1),
+ lapply(0 : (arrSize - 1),
function(index) {
obj <- callJMethod(jList, "get", as.integer(index))
@@ -572,7 +572,7 @@ mergePartitions <- function(rdd, zip) {
keys <- list()
}
if (lengthOfValues > 1) {
- values <- part[(lengthOfKeys + 1) : (len - 1)]
+ values <- part[ (lengthOfKeys + 1) : (len - 1) ]
} else {
values <- list()
}
diff --git a/R/pkg/inst/tests/packageInAJarTest.R b/R/pkg/inst/tests/packageInAJarTest.R
new file mode 100644
index 0000000000000..207a37a0cb47f
--- /dev/null
+++ b/R/pkg/inst/tests/packageInAJarTest.R
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+library(SparkR)
+library(sparkPackageTest)
+
+sc <- sparkR.init()
+
+run1 <- myfunc(5L)
+
+run2 <- myfunc(-4L)
+
+sparkR.stop()
+
+if(run1 != 6) quit(save = "no", status = 1)
+
+if(run2 != -3) quit(save = "no", status = 1)
diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R
index dca0657c57e0d..f054ac9a87d61 100644
--- a/R/pkg/inst/tests/test_binary_function.R
+++ b/R/pkg/inst/tests/test_binary_function.R
@@ -40,7 +40,7 @@ test_that("union on two RDDs", {
expect_equal(actual, c(as.list(nums), mockFile))
expect_equal(getSerializedMode(union.rdd), "byte")
- rdd<- map(text.rdd, function(x) {x})
+ rdd <- map(text.rdd, function(x) {x})
union.rdd <- unionRDD(rdd, text.rdd)
actual <- collect(union.rdd)
expect_equal(actual, as.list(c(mockFile, mockFile)))
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index 3bef69324770a..f272de78ad4a6 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", {
rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
+
+test_that("summary coefficients match with native glm", {
+ training <- createDataFrame(sqlContext, iris)
+ stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
+ coefs <- as.vector(stats$coefficients)
+ rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
+ expect_true(all(abs(rCoefs - coefs) < 1e-6))
+ expect_true(all(
+ rownames(stats$coefficients) ==
+ c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica")))
+})
diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R
index 6c3aaab8c711e..71aed2bb9d6a8 100644
--- a/R/pkg/inst/tests/test_rdd.R
+++ b/R/pkg/inst/tests/test_rdd.R
@@ -250,7 +250,7 @@ test_that("flatMapValues() on pairwise RDDs", {
expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4)))
# Generate x to x+1 for every value
- actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) }))
+ actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) }))
expect_equal(actual,
list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101),
list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201)))
@@ -293,7 +293,7 @@ test_that("sumRDD() on RDDs", {
})
test_that("keyBy on RDDs", {
- func <- function(x) { x*x }
+ func <- function(x) { x * x }
keys <- keyBy(rdd, func)
actual <- collect(keys)
expect_equal(actual, lapply(nums, function(x) { list(func(x), x) }))
@@ -311,7 +311,7 @@ test_that("repartition/coalesce on RDDs", {
r2 <- repartition(rdd, 6)
expect_equal(numPartitions(r2), 6L)
count <- length(collectPartition(r2, 0L))
- expect_true(count >=0 && count <= 4)
+ expect_true(count >= 0 && count <= 4)
# coalesce
r3 <- coalesce(rdd, 1)
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index d5db97248c770..e6d3b21ff825b 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -88,6 +88,9 @@ test_that("create DataFrame from RDD", {
df <- createDataFrame(sqlContext, rdd, list("a", "b"))
expect_is(df, "DataFrame")
expect_equal(count(df), 10)
+ expect_equal(nrow(df), 10)
+ expect_equal(ncol(df), 2)
+ expect_equal(dim(df), c(10, 2))
expect_equal(columns(df), c("a", "b"))
expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
@@ -128,7 +131,9 @@ test_that("create DataFrame from RDD", {
expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float")))
expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5))
- localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7))
+ localDF <- data.frame(name=c("John", "Smith", "Sarah"),
+ age=c(19, 23, 18),
+ height=c(164.10, 181.4, 173.7))
df <- createDataFrame(sqlContext, localDF, schema)
expect_is(df, "DataFrame")
expect_equal(count(df), 3)
@@ -489,7 +494,7 @@ test_that("head() and first() return the correct data", {
expect_equal(nrow(testFirst), 1)
})
-test_that("distinct() on DataFrames", {
+test_that("distinct() and unique on DataFrames", {
lines <- c("{\"name\":\"Michael\"}",
"{\"name\":\"Andy\", \"age\":30}",
"{\"name\":\"Justin\", \"age\":19}",
@@ -501,6 +506,10 @@ test_that("distinct() on DataFrames", {
uniques <- distinct(df)
expect_is(uniques, "DataFrame")
expect_equal(count(uniques), 3)
+
+ uniques2 <- unique(df)
+ expect_is(uniques2, "DataFrame")
+ expect_equal(count(uniques2), 3)
})
test_that("sample on a DataFrame", {
@@ -631,15 +640,18 @@ test_that("column operators", {
test_that("column functions", {
c <- SparkR:::col("a")
- c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c)
- c3 <- lower(c) + upper(c) + first(c) + last(c)
- c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
- c5 <- n(c) + n_distinct(c)
- c5 <- acos(c) + asin(c) + atan(c) + cbrt(c)
- c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c)
- c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c)
- c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c)
- c9 <- toDegrees(c) + toRadians(c)
+ c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c)
+ c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c)
+ c3 <- cosh(c) + count(c) + crc32(c) + dayofmonth(c) + dayofyear(c) + exp(c)
+ c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c)
+ c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c)
+ c6 <- log(c) + log10(c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c)
+ c7 <- mean(c) + min(c) + minute(c) + month(c) + negate(c) + quarter(c)
+ c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + second(c) + sha1(c)
+ c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c)
+ c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c)
+ c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + weekofyear(c)
+ c12 <- year(c)
df <- jsonFile(sqlContext, jsonPath)
df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20)))
@@ -666,10 +678,12 @@ test_that("column binary mathfunctions", {
expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6))
expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7))
expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8))
+ ## nolint start
expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2))
expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2))
expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2))
expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2))
+ ## nolint end
})
test_that("string operators", {
@@ -754,7 +768,7 @@ test_that("filter() on a DataFrame", {
expect_equal(count(filtered6), 2)
})
-test_that("join() on a DataFrame", {
+test_that("join() and merge() on a DataFrame", {
df <- jsonFile(sqlContext, jsonPath)
mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}",
@@ -783,6 +797,12 @@ test_that("join() on a DataFrame", {
expect_equal(names(joined4), c("newAge", "name", "test"))
expect_equal(count(joined4), 4)
expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24)
+
+ merged <- select(merge(df, df2, df$name == df2$name, "outer"),
+ alias(df$age + 5, "newAge"), df$name, df2$test)
+ expect_equal(names(merged), c("newAge", "name", "test"))
+ expect_equal(count(merged), 4)
+ expect_equal(collect(orderBy(merged, merged$name))$newAge[3], 24)
})
test_that("toJSON() returns an RDD of the correct values", {
@@ -811,7 +831,7 @@ test_that("isLocal()", {
expect_false(isLocal(df))
})
-test_that("unionAll(), except(), and intersect() on a DataFrame", {
+test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", {
df <- jsonFile(sqlContext, jsonPath)
lines <- c("{\"name\":\"Bob\", \"age\":24}",
@@ -826,6 +846,11 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", {
expect_equal(count(unioned), 6)
expect_equal(first(unioned)$name, "Michael")
+ unioned2 <- arrange(rbind(unioned, df, df2), df$age)
+ expect_is(unioned2, "DataFrame")
+ expect_equal(count(unioned2), 12)
+ expect_equal(first(unioned2)$name, "Michael")
+
excepted <- arrange(except(df, df2), desc(df$age))
expect_is(unioned, "DataFrame")
expect_equal(count(excepted), 2)
@@ -849,7 +874,7 @@ test_that("withColumn() and withColumnRenamed()", {
expect_equal(columns(newDF2)[1], "newerAge")
})
-test_that("mutate() and rename()", {
+test_that("mutate(), rename() and names()", {
df <- jsonFile(sqlContext, jsonPath)
newDF <- mutate(df, newAge = df$age + 2)
expect_equal(length(columns(newDF)), 3)
@@ -859,6 +884,10 @@ test_that("mutate() and rename()", {
newDF2 <- rename(df, newerAge = df$age)
expect_equal(length(columns(newDF2)), 2)
expect_equal(columns(newDF2)[1], "newerAge")
+
+ names(newDF2) <- c("newerName", "evenNewerAge")
+ expect_equal(length(names(newDF2)), 2)
+ expect_equal(names(newDF2)[1], "newerName")
})
test_that("write.df() on DataFrame and works with parquetFile", {
@@ -876,10 +905,10 @@ test_that("parquetFile works with multiple input paths", {
write.df(df, parquetPath2, "parquet", mode="overwrite")
parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2)
expect_is(parquetDF, "DataFrame")
- expect_equal(count(parquetDF), count(df)*2)
+ expect_equal(count(parquetDF), count(df) * 2)
})
-test_that("describe() on a DataFrame", {
+test_that("describe() and summary() on a DataFrame", {
df <- jsonFile(sqlContext, jsonPath)
stats <- describe(df, "age")
expect_equal(collect(stats)[1, "summary"], "count")
@@ -888,6 +917,10 @@ test_that("describe() on a DataFrame", {
stats <- describe(df)
expect_equal(collect(stats)[4, "name"], "Andy")
expect_equal(collect(stats)[5, "age"], "30")
+
+ stats2 <- summary(df)
+ expect_equal(collect(stats2)[4, "name"], "Andy")
+ expect_equal(collect(stats2)[5, "age"], "30")
})
test_that("dropna() on a DataFrame", {
@@ -1002,6 +1035,11 @@ test_that("crosstab() on a DataFrame", {
expect_identical(expected, ordered)
})
+test_that("SQL error message is returned from JVM", {
+ retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
+ expect_equal(grepl("Table Not Found: blah", retError), TRUE)
+})
+
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
diff --git a/R/run-tests.sh b/R/run-tests.sh
index 18a1e13bdc655..e82ad0ba2cd06 100755
--- a/R/run-tests.sh
+++ b/R/run-tests.sh
@@ -23,7 +23,7 @@ FAILED=0
LOGFILE=$FWDIR/unit-tests.out
rm -f $LOGFILE
-SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
+SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
FAILED=$((PIPESTATUS[0]||$FAILED))
if [[ $FAILED != 0 ]]; then
diff --git a/build/mvn b/build/mvn
index f62f61ee1c416..ec0380afad319 100755
--- a/build/mvn
+++ b/build/mvn
@@ -51,11 +51,11 @@ install_app() {
# check if we have curl installed
# download application
[ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \
- echo "exec: curl ${curl_opts} ${remote_tarball}" && \
+ echo "exec: curl ${curl_opts} ${remote_tarball}" 1>&2 && \
curl ${curl_opts} "${remote_tarball}" > "${local_tarball}"
# if the file still doesn't exist, lets try `wget` and cross our fingers
[ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \
- echo "exec: wget ${wget_opts} ${remote_tarball}" && \
+ echo "exec: wget ${wget_opts} ${remote_tarball}" 1>&2 && \
wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}"
# if both were unsuccessful, exit
[ ! -f "${local_tarball}" ] && \
@@ -82,7 +82,7 @@ install_mvn() {
# Install zinc under the build/ folder
install_zinc() {
local zinc_path="zinc-0.3.5.3/bin/zinc"
- [ ! -f "${zinc_path}" ] && ZINC_INSTALL_FLAG=1
+ [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1
install_app \
"http://downloads.typesafe.com/zinc/0.3.5.3" \
"zinc-0.3.5.3.tgz" \
@@ -135,9 +135,9 @@ cd "${_CALLING_DIR}"
# Now that zinc is ensured to be installed, check its status and, if its
# not running or just installed, start it
-if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then
+if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status -port ${ZINC_PORT}`" ]; then
export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"}
- ${ZINC_BIN} -shutdown
+ ${ZINC_BIN} -shutdown -port ${ZINC_PORT}
${ZINC_BIN} -start -port ${ZINC_PORT} \
-scala-compiler "${SCALA_COMPILER}" \
-scala-library "${SCALA_LIBRARY}" &>/dev/null
@@ -146,7 +146,7 @@ fi
# Set any `mvn` options if not already present
export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"}
-echo "Using \`mvn\` from path: $MVN_BIN"
+echo "Using \`mvn\` from path: $MVN_BIN" 1>&2
# Last, call the `mvn` command as usual
-${MVN_BIN} "$@"
+${MVN_BIN} -DzincPort=${ZINC_PORT} "$@"
diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash
index 7930a38b9674a..615f848394650 100755
--- a/build/sbt-launch-lib.bash
+++ b/build/sbt-launch-lib.bash
@@ -38,8 +38,7 @@ dlog () {
acquire_sbt_jar () {
SBT_VERSION=`awk -F "=" '/sbt\.version/ {print $2}' ./project/build.properties`
- URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
- URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
+ URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
JAR=build/sbt-launch-${SBT_VERSION}.jar
sbt_jar=$JAR
@@ -51,12 +50,10 @@ acquire_sbt_jar () {
printf "Attempting to fetch sbt\n"
JAR_DL="${JAR}.part"
if [ $(command -v curl) ]; then
- (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\
- (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\
+ curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\
mv "${JAR_DL}" "${JAR}"
elif [ $(command -v wget) ]; then
- (wget --quiet ${URL1} -O "${JAR_DL}" ||\
- (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\
+ wget --quiet ${URL1} -O "${JAR_DL}" &&\
mv "${JAR_DL}" "${JAR}"
else
printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index 192d3ae091134..c05fe381a36a7 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -38,6 +38,7 @@
# - SPARK_WORKER_INSTANCES, to set the number of worker processes per node
# - SPARK_WORKER_DIR, to set the working directory of worker processes
# - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y")
+# - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g).
# - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y")
# - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y")
# - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y")
diff --git a/core/pom.xml b/core/pom.xml
index 6fa87ec6a24af..0e53a79fd2235 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -46,30 +46,10 @@
com.twitterchill_${scala.binary.version}
-
-
- org.ow2.asm
- asm
-
-
- org.ow2.asm
- asm-commons
-
- com.twitterchill-java
-
-
- org.ow2.asm
- asm
-
-
- org.ow2.asm
- asm-commons
-
- org.apache.hadoop
@@ -286,7 +266,7 @@
org.tachyonprojecttachyon-client
- 0.6.4
+ 0.7.0org.apache.hadoop
@@ -297,36 +277,12 @@
curator-recipes
- org.eclipse.jetty
- jetty-jsp
+ org.tachyonproject
+ tachyon-underfs-glusterfs
- org.eclipse.jetty
- jetty-webapp
-
-
- org.eclipse.jetty
- jetty-server
-
-
- org.eclipse.jetty
- jetty-servlet
-
-
- junit
- junit
-
-
- org.powermock
- powermock-module-junit4
-
-
- org.powermock
- powermock-api-mockito
-
-
- org.apache.curator
- curator-test
+ org.tachyonproject
+ tachyon-underfs-s3
diff --git a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
index 2090efd3b9990..d4c42b38ac224 100644
--- a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
+++ b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
@@ -23,11 +23,13 @@
// See
// http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html
abstract class JavaSparkContextVarargsWorkaround {
- public <T> JavaRDD<T> union(JavaRDD<T>... rdds) {
+
+ @SafeVarargs
+ public final <T> JavaRDD<T> union(JavaRDD<T>... rdds) {
if (rdds.length == 0) {
throw new IllegalArgumentException("Union called on empty list");
}
- ArrayList<JavaRDD<T>> rest = new ArrayList<JavaRDD<T>>(rdds.length - 1);
+ List<JavaRDD<T>> rest = new ArrayList<>(rdds.length - 1);
for (int i = 1; i < rdds.length; i++) {
rest.add(rdds[i]);
}
@@ -38,18 +40,19 @@ public JavaDoubleRDD union(JavaDoubleRDD... rdds) {
if (rdds.length == 0) {
throw new IllegalArgumentException("Union called on empty list");
}
- ArrayList rest = new ArrayList(rdds.length - 1);
+ List rest = new ArrayList<>(rdds.length - 1);
for (int i = 1; i < rdds.length; i++) {
rest.add(rdds[i]);
}
return union(rdds[0], rest);
}
- public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V>... rdds) {
+ @SafeVarargs
+ public final <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V>... rdds) {
if (rdds.length == 0) {
throw new IllegalArgumentException("Union called on empty list");
}
- ArrayList<JavaPairRDD<K, V>> rest = new ArrayList<JavaPairRDD<K, V>>(rdds.length - 1);
+ List<JavaPairRDD<K, V>> rest = new ArrayList<>(rdds.length - 1);
for (int i = 1; i < rdds.length; i++) {
rest.add(rdds[i]);
}
@@ -57,7 +60,7 @@ public JavaPairRDD union(JavaPairRDD... rdds) {
}
// These methods take separate "first" and "rest" elements to avoid having the same type erasure
- abstract public JavaRDD union(JavaRDD first, List> rest);
- abstract public JavaDoubleRDD union(JavaDoubleRDD first, List rest);
- abstract public JavaPairRDD union(JavaPairRDD first, List> rest);
+ public abstract JavaRDD union(JavaRDD first, List> rest);
+ public abstract JavaDoubleRDD union(JavaDoubleRDD first, List rest);
+ public abstract JavaPairRDD union(JavaPairRDD first, List> rest);
}
diff --git a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
index 0399abc63c235..0e58bb4f7101c 100644
--- a/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
+++ b/core/src/main/java/org/apache/spark/serializer/DummySerializerInstance.java
@@ -25,7 +25,7 @@
import scala.reflect.ClassTag;
import org.apache.spark.annotation.Private;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.Platform;
/**
* Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
@@ -49,7 +49,7 @@ public void flush() {
try {
s.flush();
} catch (IOException e) {
- PlatformDependent.throwException(e);
+ Platform.throwException(e);
}
}
@@ -64,7 +64,7 @@ public void close() {
try {
s.close();
} catch (IOException e) {
- PlatformDependent.throwException(e);
+ Platform.throwException(e);
}
}
};
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 1aa6ba4201261..3d1ef0c48adc5 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -17,6 +17,7 @@
package org.apache.spark.shuffle.unsafe;
+import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
@@ -33,8 +34,11 @@
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.storage.*;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.DiskBlockObjectWriter;
+import org.apache.spark.storage.TempShuffleBlockId;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
@@ -67,7 +71,7 @@ final class UnsafeShuffleExternalSorter {
private final int pageSizeBytes;
@VisibleForTesting
final int maxRecordSizeBytes;
- private final TaskMemoryManager memoryManager;
+ private final TaskMemoryManager taskMemoryManager;
private final ShuffleMemoryManager shuffleMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
@@ -86,9 +90,12 @@ final class UnsafeShuffleExternalSorter {
private final LinkedList spills = new LinkedList();
+ /** Peak memory used by this sorter so far, in bytes. **/
+ private long peakMemoryUsedBytes;
+
// These variables are reset after spilling:
- private UnsafeShuffleInMemorySorter sorter;
- private MemoryBlock currentPage = null;
+ @Nullable private UnsafeShuffleInMemorySorter inMemSorter;
+ @Nullable private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;
@@ -101,17 +108,17 @@ public UnsafeShuffleExternalSorter(
int numPartitions,
SparkConf conf,
ShuffleWriteMetrics writeMetrics) throws IOException {
- this.memoryManager = memoryManager;
+ this.taskMemoryManager = memoryManager;
this.shuffleMemoryManager = shuffleMemoryManager;
this.blockManager = blockManager;
this.taskContext = taskContext;
this.initialSize = initialSize;
+ this.peakMemoryUsedBytes = initialSize;
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.pageSizeBytes = (int) Math.min(
- PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES,
- conf.getSizeAsBytes("spark.buffer.pageSize", "64m"));
+ PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
this.maxRecordSizeBytes = pageSizeBytes - 4;
this.writeMetrics = writeMetrics;
initializeForWriting();
@@ -129,7 +136,7 @@ private void initializeForWriting() throws IOException {
throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
}
- this.sorter = new UnsafeShuffleInMemorySorter(initialSize);
+ this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize);
}
/**
@@ -156,7 +163,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
// This call performs the actual sort.
final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
- sorter.getSortedIterator();
+ inMemSorter.getSortedIterator();
// Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
// after SPARK-5581 is fixed.
@@ -202,18 +209,14 @@ private void writeSortedFile(boolean isLastFile) throws IOException {
}
final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
- final Object recordPage = memoryManager.getPage(recordPointer);
- final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer);
- int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage);
+ final Object recordPage = taskMemoryManager.getPage(recordPointer);
+ final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
+ int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage);
long recordReadPosition = recordOffsetInPage + 4; // skip over record length
while (dataRemaining > 0) {
final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
- PlatformDependent.copyMemory(
- recordPage,
- recordReadPosition,
- writeBuffer,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- toTransfer);
+ Platform.copyMemory(
+ recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
writer.write(writeBuffer, 0, toTransfer);
recordReadPosition += toTransfer;
dataRemaining -= toTransfer;
@@ -265,9 +268,9 @@ void spill() throws IOException {
spills.size() > 1 ? " times" : " time");
writeSortedFile(false);
- final long sorterMemoryUsage = sorter.getMemoryUsage();
- sorter = null;
- shuffleMemoryManager.release(sorterMemoryUsage);
+ final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
+ inMemSorter = null;
+ shuffleMemoryManager.release(inMemSorterMemoryUsage);
final long spillSize = freeMemory();
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
@@ -279,13 +282,29 @@ private long getMemoryUsage() {
for (MemoryBlock page : allocatedPages) {
totalPageSize += page.size();
}
- return sorter.getMemoryUsage() + totalPageSize;
+ return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
+ }
+
+ private void updatePeakMemoryUsed() {
+ long mem = getMemoryUsage();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
}
private long freeMemory() {
+ updatePeakMemoryUsed();
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
- memoryManager.freePage(block);
+ taskMemoryManager.freePage(block);
shuffleMemoryManager.release(block.size());
memoryFreed += block.size();
}
@@ -299,54 +318,53 @@ private long freeMemory() {
/**
* Force all memory and spill files to be deleted; called by shuffle error-handling code.
*/
- public void cleanupAfterError() {
+ public void cleanupResources() {
freeMemory();
for (SpillInfo spill : spills) {
if (spill.file.exists() && !spill.file.delete()) {
logger.error("Unable to delete spill file {}", spill.file.getPath());
}
}
- if (sorter != null) {
- shuffleMemoryManager.release(sorter.getMemoryUsage());
- sorter = null;
+ if (inMemSorter != null) {
+ shuffleMemoryManager.release(inMemSorter.getMemoryUsage());
+ inMemSorter = null;
}
}
/**
- * Checks whether there is enough space to insert a new record into the sorter.
- *
- * @param requiredSpace the required space in the data page, in bytes, including space for storing
- * the record size.
-
- * @return true if the record can be inserted without requiring more allocations, false otherwise.
- */
- private boolean haveSpaceForRecord(int requiredSpace) {
- assert (requiredSpace > 0);
- return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
- }
-
- /**
- * Allocates more memory in order to insert an additional record. This will request additional
- * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
- * obtained.
- *
- * @param requiredSpace the required space in the data page, in bytes, including space for storing
- * the record size.
+ * Checks whether there is enough space to insert an additional record in to the sort pointer
+ * array and grows the array if additional space is required. If the required space cannot be
+ * obtained, then the in-memory data will be spilled to disk.
*/
- private void allocateSpaceForRecord(int requiredSpace) throws IOException {
- if (!sorter.hasSpaceForAnotherRecord()) {
+ private void growPointerArrayIfNecessary() throws IOException {
+ assert(inMemSorter != null);
+ if (!inMemSorter.hasSpaceForAnotherRecord()) {
logger.debug("Attempting to expand sort pointer array");
- final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+ final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
if (memoryAcquired < memoryToGrowPointerArray) {
shuffleMemoryManager.release(memoryAcquired);
spill();
} else {
- sorter.expandPointerArray();
+ inMemSorter.expandPointerArray();
shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
}
}
+ }
+
+ /**
+ * Allocates more memory in order to insert an additional record. This will request additional
+ * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+ * obtained.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size. This must be less than or equal to the page size (records
+ * that exceed the page size are handled via a different code path which uses
+ * special overflow pages).
+ */
+ private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
+ growPointerArrayIfNecessary();
if (requiredSpace > freeSpaceInCurrentPage) {
logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
freeSpaceInCurrentPage);
@@ -367,7 +385,7 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
}
}
- currentPage = memoryManager.allocatePage(pageSizeBytes);
+ currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
currentPagePosition = currentPage.getBaseOffset();
freeSpaceInCurrentPage = pageSizeBytes;
allocatedPages.add(currentPage);
@@ -383,27 +401,54 @@ public void insertRecord(
long recordBaseOffset,
int lengthInBytes,
int partitionId) throws IOException {
+
+ growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
final int totalSpaceRequired = lengthInBytes + 4;
- if (!haveSpaceForRecord(totalSpaceRequired)) {
- allocateSpaceForRecord(totalSpaceRequired);
+
+ // --- Figure out where to insert the new record ----------------------------------------------
+
+ final MemoryBlock dataPage;
+ long dataPagePosition;
+ boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
+ if (useOverflowPage) {
+ long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
+ // The record is larger than the page size, so allocate a special overflow page just to hold
+ // that record.
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+ if (memoryGranted != overflowPageSize) {
+ shuffleMemoryManager.release(memoryGranted);
+ spill();
+ final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+ if (memoryGrantedAfterSpill != overflowPageSize) {
+ shuffleMemoryManager.release(memoryGrantedAfterSpill);
+ throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
+ }
+ }
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ allocatedPages.add(overflowPage);
+ dataPage = overflowPage;
+ dataPagePosition = overflowPage.getBaseOffset();
+ } else {
+ // The record is small enough to fit in a regular data page, but the current page might not
+ // have enough space to hold it (or no pages have been allocated yet).
+ acquireNewPageIfNecessary(totalSpaceRequired);
+ dataPage = currentPage;
+ dataPagePosition = currentPagePosition;
+ // Update bookkeeping information
+ freeSpaceInCurrentPage -= totalSpaceRequired;
+ currentPagePosition += totalSpaceRequired;
}
+ final Object dataPageBaseObject = dataPage.getBaseObject();
final long recordAddress =
- memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
- final Object dataPageBaseObject = currentPage.getBaseObject();
- PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
- currentPagePosition += 4;
- freeSpaceInCurrentPage -= 4;
- PlatformDependent.copyMemory(
- recordBaseObject,
- recordBaseOffset,
- dataPageBaseObject,
- currentPagePosition,
- lengthInBytes);
- currentPagePosition += lengthInBytes;
- freeSpaceInCurrentPage -= lengthInBytes;
- sorter.insertRecord(recordAddress, partitionId);
+ taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
+ Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
+ dataPagePosition += 4;
+ Platform.copyMemory(
+ recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
+ assert(inMemSorter != null);
+ inMemSorter.insertRecord(recordAddress, partitionId);
}
/**
@@ -415,14 +460,14 @@ public void insertRecord(
*/
public SpillInfo[] closeAndGetSpills() throws IOException {
try {
- if (sorter != null) {
+ if (inMemSorter != null) {
// Do not count the final file towards the spill count.
writeSortedFile(true);
freeMemory();
}
return spills.toArray(new SpillInfo[spills.size()]);
} catch (IOException e) {
- cleanupAfterError();
+ cleanupResources();
throw e;
}
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index d47d6fc9c2ac4..2389c28b28395 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -17,14 +17,15 @@
package org.apache.spark.shuffle.unsafe;
+import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
import java.util.Iterator;
-import javax.annotation.Nullable;
import scala.Option;
import scala.Product2;
import scala.collection.JavaConversions;
+import scala.collection.immutable.Map;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
@@ -37,10 +38,10 @@
import org.apache.spark.*;
import org.apache.spark.annotation.Private;
+import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.LZFCompressionCodec;
-import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
@@ -52,7 +53,7 @@
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
@Private
@@ -78,8 +79,9 @@ public class UnsafeShuffleWriter extends ShuffleWriter {
private final SparkConf sparkConf;
private final boolean transferToEnabled;
- private MapStatus mapStatus = null;
- private UnsafeShuffleExternalSorter sorter = null;
+ @Nullable private MapStatus mapStatus;
+ @Nullable private UnsafeShuffleExternalSorter sorter;
+ private long peakMemoryUsedBytes = 0;
/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
@@ -131,9 +133,28 @@ public UnsafeShuffleWriter(
@VisibleForTesting
public int maxRecordSizeBytes() {
+ assert(sorter != null);
return sorter.maxRecordSizeBytes;
}
+ private void updatePeakMemoryUsed() {
+ // sorter can be null if this writer is closed
+ if (sorter != null) {
+ long mem = sorter.getPeakMemoryUsedBytes();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
+ }
+
/**
* This convenience method should only be called in test code.
*/
@@ -144,7 +165,7 @@ public void write(Iterator> records) throws IOException {
@Override
public void write(scala.collection.Iterator> records) throws IOException {
- // Keep track of success so we know if we ecountered an exception
+ // Keep track of success so we know if we encountered an exception
// We do this rather than a standard try/catch/re-throw to handle
// generic throwables.
boolean success = false;
@@ -157,7 +178,7 @@ public void write(scala.collection.Iterator> records) throws IOEx
} finally {
if (sorter != null) {
try {
- sorter.cleanupAfterError();
+ sorter.cleanupResources();
} catch (Exception e) {
// Only throw this error if we won't be masking another
// error.
@@ -189,6 +210,8 @@ private void open() throws IOException {
@VisibleForTesting
void closeAndWriteOutput() throws IOException {
+ assert(sorter != null);
+ updatePeakMemoryUsed();
serBuffer = null;
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
@@ -209,6 +232,7 @@ void closeAndWriteOutput() throws IOException {
@VisibleForTesting
void insertRecordIntoSorter(Product2 record) throws IOException {
+ assert(sorter != null);
final K key = record._1();
final int partitionId = partitioner.getPartition(key);
serBuffer.reset();
@@ -220,7 +244,7 @@ void insertRecordIntoSorter(Product2 record) throws IOException {
assert (serializedRecordSize > 0);
sorter.insertRecord(
- serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+ serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
}
@VisibleForTesting
@@ -431,6 +455,14 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
@Override
public Option stop(boolean success) {
try {
+ // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite)
+ Map<String, Accumulator<Object>> internalAccumulators =
+ taskContext.internalMetricsToAccumulators();
+ if (internalAccumulators != null) {
+ internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY())
+ .add(getPeakMemoryUsedBytes());
+ }
+
if (stopping) {
return Option.apply(null);
} else {
@@ -450,7 +482,7 @@ public Option stop(boolean success) {
if (sorter != null) {
// If sorter is non-null, then this implies that we called stop() in response to an error,
// so we need to clean up memory and spill files created by the sorter
- sorter.cleanupAfterError();
+ sorter.cleanupResources();
}
}
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
similarity index 59%
rename from unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
rename to core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 198e0684f32f8..5f3a4fcf4d585 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -17,34 +17,48 @@
package org.apache.spark.unsafe.map;
-import java.lang.Override;
-import java.lang.UnsupportedOperationException;
+import javax.annotation.Nullable;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-import org.apache.spark.unsafe.*;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.bitset.BitSet;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
-import org.apache.spark.unsafe.memory.*;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
/**
* An append-only hash map where keys and values are contiguous regions of bytes.
- *
+ *
* This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
* which is guaranteed to exhaust the space.
- *
+ *
* The map can support up to 2^29 keys. If the key cardinality is higher than this, you should
* probably be using sorting instead of hashing for better cache locality.
- *
- * This class is not thread safe.
+ *
+ * The key and values under the hood are stored together, in the following format:
+ * Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4
+ * Bytes 4 to 8: len(k)
+ * Bytes 8 to 8 + len(k): key data
+ * Bytes 8 + len(k) to 8 + len(k) + len(v): value data
+ *
+ * This means that the first four bytes store the entire record (key + value) length. This format
+ * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
+ * so we can pass records from this map directly into the sorter to sort records in place.
*/
public final class BytesToBytesMap {
+ private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
+
private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
@@ -54,7 +68,9 @@ public final class BytesToBytesMap {
*/
private static final int END_OF_PAGE_MARKER = -1;
- private final TaskMemoryManager memoryManager;
+ private final TaskMemoryManager taskMemoryManager;
+
+ private final ShuffleMemoryManager shuffleMemoryManager;
/**
* A linked list for tracking all allocated data pages so that we can free all of our memory.
@@ -92,7 +108,7 @@ public final class BytesToBytesMap {
* Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i},
* while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode.
*/
- private LongArray longArray;
+ @Nullable private LongArray longArray;
// TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode
// and exploit word-alignment to use fewer bits to hold the address. This might let us store
// only one long per map entry, increasing the chance that this array will fit in cache at the
@@ -107,7 +123,7 @@ public final class BytesToBytesMap {
* A {@link BitSet} used to track location of the map where the key is set.
* Size of the bitset should be half of the size of the long array.
*/
- private BitSet bitset;
+ @Nullable private BitSet bitset;
private final double loadFactor;
@@ -120,7 +136,7 @@ public final class BytesToBytesMap {
/**
* Number of keys defined in the map.
*/
- private int size;
+ private int numElements;
/**
* The map will be expanded once the number of keys exceeds this threshold.
@@ -149,13 +165,17 @@ public final class BytesToBytesMap {
private long numHashCollisions = 0;
+ private long peakMemoryUsedBytes = 0L;
+
public BytesToBytesMap(
- TaskMemoryManager memoryManager,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
double loadFactor,
long pageSizeBytes,
boolean enablePerfMetrics) {
- this.memoryManager = memoryManager;
+ this.taskMemoryManager = taskMemoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
this.loadFactor = loadFactor;
this.loc = new Location();
this.pageSizeBytes = pageSizeBytes;
@@ -172,49 +192,77 @@ public BytesToBytesMap(
TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
}
allocate(initialCapacity);
+
+ // Acquire a new page as soon as we construct the map to ensure that we have at least
+ // one page to work with. Otherwise, other operators in the same task may starve this
+ // map (SPARK-9747).
+ acquireNewPage();
}
public BytesToBytesMap(
- TaskMemoryManager memoryManager,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes) {
- this(memoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+ this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
}
public BytesToBytesMap(
- TaskMemoryManager memoryManager,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes,
boolean enablePerfMetrics) {
- this(memoryManager, initialCapacity, 0.70, pageSizeBytes, enablePerfMetrics);
+ this(
+ taskMemoryManager,
+ shuffleMemoryManager,
+ initialCapacity,
+ 0.70,
+ pageSizeBytes,
+ enablePerfMetrics);
}
/**
* Returns the number of keys defined in the map.
*/
- public int size() { return size; }
+ public int numElements() { return numElements; }
- private static final class BytesToBytesMapIterator implements Iterator {
+ public static final class BytesToBytesMapIterator implements Iterator {
private final int numRecords;
private final Iterator dataPagesIterator;
private final Location loc;
+ private MemoryBlock currentPage = null;
private int currentRecordNumber = 0;
private Object pageBaseObject;
private long offsetInPage;
- BytesToBytesMapIterator(int numRecords, Iterator dataPagesIterator, Location loc) {
+ // If this iterator destructive or not. When it is true, it frees each page as it moves onto
+ // next one.
+ private boolean destructive = false;
+ private BytesToBytesMap bmap;
+
+ private BytesToBytesMapIterator(
+ int numRecords, Iterator dataPagesIterator, Location loc,
+ boolean destructive, BytesToBytesMap bmap) {
this.numRecords = numRecords;
this.dataPagesIterator = dataPagesIterator;
this.loc = loc;
+ this.destructive = destructive;
+ this.bmap = bmap;
if (dataPagesIterator.hasNext()) {
advanceToNextPage();
}
}
private void advanceToNextPage() {
- final MemoryBlock currentPage = dataPagesIterator.next();
+ if (destructive && currentPage != null) {
+ dataPagesIterator.remove();
+ this.bmap.taskMemoryManager.freePage(currentPage);
+ this.bmap.shuffleMemoryManager.release(currentPage.size());
+ }
+ currentPage = dataPagesIterator.next();
pageBaseObject = currentPage.getBaseObject();
offsetInPage = currentPage.getBaseOffset();
}
@@ -226,13 +274,13 @@ public boolean hasNext() {
@Override
public Location next() {
- int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
- if (keyLength == END_OF_PAGE_MARKER) {
+ int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
+ if (totalLength == END_OF_PAGE_MARKER) {
advanceToNextPage();
- keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+ totalLength = Platform.getInt(pageBaseObject, offsetInPage);
}
- loc.with(pageBaseObject, offsetInPage);
- offsetInPage += 8 + 8 + keyLength + loc.getValueLength();
+ loc.with(currentPage, offsetInPage);
+ offsetInPage += 4 + totalLength;
currentRecordNumber++;
return loc;
}
@@ -251,8 +299,22 @@ public void remove() {
* If any other lookups or operations are performed on this map while iterating over it, including
* `lookup()`, the behavior of the returned iterator is undefined.
*/
- public Iterator iterator() {
- return new BytesToBytesMapIterator(size, dataPages.iterator(), loc);
+ public BytesToBytesMapIterator iterator() {
+ return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this);
+ }
+
+ /**
+ * Returns a destructive iterator for iterating over the entries of this map. It frees each page
+ * as it moves onto next one. Notice: it is illegal to call any method on the map after
+ * `destructiveIterator()` has been called.
+ *
+ * For efficiency, all calls to `next()` will return the same {@link Location} object.
+ *
+ * If any other lookups or operations are performed on this map while iterating over it, including
+ * `lookup()`, the behavior of the returned iterator is undefined.
+ */
+ public BytesToBytesMapIterator destructiveIterator() {
+ return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this);
}
/**
@@ -265,6 +327,23 @@ public Location lookup(
Object keyBaseObject,
long keyBaseOffset,
int keyRowLengthBytes) {
+ safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc);
+ return loc;
+ }
+
+ /**
+ * Looks up a key, and saves the result in provided `loc`.
+ *
+ * This is a thread-safe version of `lookup`, could be used by multiple threads.
+ */
+ public void safeLookup(
+ Object keyBaseObject,
+ long keyBaseOffset,
+ int keyRowLengthBytes,
+ Location loc) {
+ assert(bitset != null);
+ assert(longArray != null);
+
if (enablePerfMetrics) {
numKeyLookups++;
}
@@ -277,7 +356,8 @@ public Location lookup(
}
if (!bitset.isSet(pos)) {
// This is a new key.
- return loc.with(pos, hashcode, false);
+ loc.with(pos, hashcode, false);
+ return;
} else {
long stored = longArray.get(pos * 2 + 1);
if ((int) (stored) == hashcode) {
@@ -295,7 +375,7 @@ public Location lookup(
keyRowLengthBytes
);
if (areEqual) {
- return loc;
+ return;
} else {
if (enablePerfMetrics) {
numHashCollisions++;
@@ -328,23 +408,33 @@ public final class Location {
private int keyLength;
private int valueLength;
+ /**
+ * Memory page containing the record. Only set if created by {@link BytesToBytesMap#iterator()}.
+ */
+ @Nullable private MemoryBlock memoryPage;
+
private void updateAddressesAndSizes(long fullKeyAddress) {
updateAddressesAndSizes(
- memoryManager.getPage(fullKeyAddress), memoryManager.getOffsetInPage(fullKeyAddress));
+ taskMemoryManager.getPage(fullKeyAddress),
+ taskMemoryManager.getOffsetInPage(fullKeyAddress));
}
- private void updateAddressesAndSizes(Object page, long keyOffsetInPage) {
- long position = keyOffsetInPage;
- keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
- position += 8; // word used to store the key size
- keyMemoryLocation.setObjAndOffset(page, position);
- position += keyLength;
- valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
- position += 8; // word used to store the key size
- valueMemoryLocation.setObjAndOffset(page, position);
+ private void updateAddressesAndSizes(final Object page, final long offsetInPage) {
+ long position = offsetInPage;
+ final int totalLength = Platform.getInt(page, position);
+ position += 4;
+ keyLength = Platform.getInt(page, position);
+ position += 4;
+ valueLength = totalLength - keyLength - 4;
+
+ keyMemoryLocation.setObjAndOffset(page, position);
+
+ position += keyLength;
+ valueMemoryLocation.setObjAndOffset(page, position);
}
- Location with(int pos, int keyHashcode, boolean isDefined) {
+ private Location with(int pos, int keyHashcode, boolean isDefined) {
+ assert(longArray != null);
this.pos = pos;
this.isDefined = isDefined;
this.keyHashcode = keyHashcode;
@@ -355,12 +445,21 @@ Location with(int pos, int keyHashcode, boolean isDefined) {
return this;
}
- Location with(Object page, long keyOffsetInPage) {
+ private Location with(MemoryBlock page, long offsetInPage) {
this.isDefined = true;
- updateAddressesAndSizes(page, keyOffsetInPage);
+ this.memoryPage = page;
+ updateAddressesAndSizes(page.getBaseObject(), offsetInPage);
return this;
}
+ /**
+ * Returns the memory page that contains the current record.
+ * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}.
+ */
+ public MemoryBlock getMemoryPage() {
+ return this.memoryPage;
+ }
+
/**
* Returns true if the key is defined at this position, and false otherwise.
*/
@@ -411,7 +510,8 @@ public int getValueLength() {
/**
* Store a new key and value. This method may only be called once for a given key; if you want
* to update the value associated with a key, then you can directly manipulate the bytes stored
- * at the value address.
+ * at the value address. The return value indicates whether the put succeeded or whether it
+ * failed because additional memory could not be acquired.
*
* It is only valid to call this method immediately after calling `lookup()` using the same key.
*
@@ -428,14 +528,19 @@ public int getValueLength() {
*
* Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
* if (!loc.isDefined()) {
- * loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)
+ * if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+ * // handle failure to grow map (by spilling, for example)
+ * }
* }
*
*
* Unspecified behavior if the key is not defined.
*
+ *
+ * @return true if the put() was successful and false if the put() failed because memory could
+ * not be acquired.
*/
- public void putNewKey(
+ public boolean putNewKey(
Object keyBaseObject,
long keyBaseOffset,
int keyLengthBytes,
@@ -445,66 +550,128 @@ public void putNewKey(
assert (!isDefined) : "Can only set value once for a key";
assert (keyLengthBytes % 8 == 0);
assert (valueLengthBytes % 8 == 0);
- if (size == MAX_CAPACITY) {
+ assert(bitset != null);
+ assert(longArray != null);
+
+ if (numElements == MAX_CAPACITY) {
throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
}
+
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
- // (8 byte key length) (key) (8 byte value length) (value)
- final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
- assert (requiredSize <= pageSizeBytes - 8); // Reserve 8 bytes for the end-of-page marker.
- size++;
- bitset.set(pos);
-
- // If there's not enough space in the current page, allocate a new page (8 bytes are reserved
- // for the end-of-page marker).
- if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
+ // (8 byte key length) (key) (value)
+ final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
+
+ // --- Figure out where to insert the new record ---------------------------------------------
+
+ final MemoryBlock dataPage;
+ final Object dataPageBaseObject;
+ final long dataPageInsertOffset;
+ boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
+ if (useOverflowPage) {
+ // The record is larger than the page size, so allocate a special overflow page just to hold
+ // that record.
+ final long memoryRequested = requiredSize + 8;
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryGranted != memoryRequested) {
+ shuffleMemoryManager.release(memoryGranted);
+ logger.debug("Failed to acquire {} bytes of memory", memoryRequested);
+ return false;
+ }
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested);
+ dataPages.add(overflowPage);
+ dataPage = overflowPage;
+ dataPageBaseObject = overflowPage.getBaseObject();
+ dataPageInsertOffset = overflowPage.getBaseOffset();
+ } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
+ // The record can fit in a data page, but either we have not allocated any pages yet or
+ // the current page does not have enough space.
if (currentDataPage != null) {
// There wasn't enough space in the current page, so write an end-of-page marker:
final Object pageBaseObject = currentDataPage.getBaseObject();
final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
- PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
+ Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
+ }
+ if (!acquireNewPage()) {
+ return false;
}
- MemoryBlock newPage = memoryManager.allocatePage(pageSizeBytes);
- dataPages.add(newPage);
- pageCursor = 0;
- currentDataPage = newPage;
+ dataPage = currentDataPage;
+ dataPageBaseObject = currentDataPage.getBaseObject();
+ dataPageInsertOffset = currentDataPage.getBaseOffset();
+ } else {
+ // There is enough space in the current data page.
+ dataPage = currentDataPage;
+ dataPageBaseObject = currentDataPage.getBaseObject();
+ dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
}
- // Compute all of our offsets up-front:
- final Object pageBaseObject = currentDataPage.getBaseObject();
- final long pageBaseOffset = currentDataPage.getBaseOffset();
- final long keySizeOffsetInPage = pageBaseOffset + pageCursor;
- pageCursor += 8; // word used to store the key size
- final long keyDataOffsetInPage = pageBaseOffset + pageCursor;
- pageCursor += keyLengthBytes;
- final long valueSizeOffsetInPage = pageBaseOffset + pageCursor;
- pageCursor += 8; // word used to store the value size
- final long valueDataOffsetInPage = pageBaseOffset + pageCursor;
- pageCursor += valueLengthBytes;
+ // --- Append the key and value data to the current data page --------------------------------
+ long insertCursor = dataPageInsertOffset;
+
+ // Compute all of our offsets up-front:
+ final long recordOffset = insertCursor;
+ insertCursor += 4;
+ final long keyLengthOffset = insertCursor;
+ insertCursor += 4;
+ final long keyDataOffsetInPage = insertCursor;
+ insertCursor += keyLengthBytes;
+ final long valueDataOffsetInPage = insertCursor;
+ insertCursor += valueLengthBytes; // word used to store the value size
+
+ Platform.putInt(dataPageBaseObject, recordOffset,
+ keyLengthBytes + valueLengthBytes + 4);
+ Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
// Copy the key
- PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes);
- PlatformDependent.copyMemory(
- keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes);
+ Platform.copyMemory(
+ keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
// Copy the value
- PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
- PlatformDependent.copyMemory(
- valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes);
+ Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
+ valueDataOffsetInPage, valueLengthBytes);
+
+ // --- Update bookeeping data structures -----------------------------------------------------
- final long storedKeyAddress = memoryManager.encodePageNumberAndOffset(
- currentDataPage, keySizeOffsetInPage);
+ if (useOverflowPage) {
+ // Store the end-of-page marker at the end of the data page
+ Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
+ } else {
+ pageCursor += requiredSize;
+ }
+
+ numElements++;
+ bitset.set(pos);
+ final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
+ dataPage, recordOffset);
longArray.set(pos * 2, storedKeyAddress);
longArray.set(pos * 2 + 1, keyHashcode);
updateAddressesAndSizes(storedKeyAddress);
isDefined = true;
- if (size > growthThreshold && longArray.size() < MAX_CAPACITY) {
+ if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
growAndRehash();
}
+ return true;
}
}
+ /**
+ * Acquire a new page from the {@link ShuffleMemoryManager}.
+ * @return whether there is enough space to allocate the new page.
+ */
+ private boolean acquireNewPage() {
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryGranted != pageSizeBytes) {
+ shuffleMemoryManager.release(memoryGranted);
+ logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+ return false;
+ }
+ MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ dataPages.add(newPage);
+ pageCursor = 0;
+ currentDataPage = newPage;
+ return true;
+ }
+
/**
* Allocate new data structures for this map. When calling this outside of the constructor,
* make sure to keep references to the old data structures so that you can free them.
@@ -514,9 +681,9 @@ public void putNewKey(
private void allocate(int capacity) {
assert (capacity >= 0);
// The capacity needs to be divisible by 64 so that our bit set can be sized properly
- capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64);
+ capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
assert (capacity <= MAX_CAPACITY);
- longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2));
+ longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
this.growthThreshold = (int) (capacity * loadFactor);
@@ -527,32 +694,60 @@ private void allocate(int capacity) {
* Free all allocated memory associated with this map, including the storage for keys and values
* as well as the hash map array itself.
*
- * This method is idempotent.
+ * This method is idempotent and can be called multiple times.
*/
public void free() {
- if (longArray != null) {
- memoryManager.free(longArray.memoryBlock());
- longArray = null;
- }
- if (bitset != null) {
- // The bitset's heap memory isn't managed by a memory manager, so no need to free it here.
- bitset = null;
- }
+ updatePeakMemoryUsed();
+ longArray = null;
+ bitset = null;
Iterator dataPagesIterator = dataPages.iterator();
while (dataPagesIterator.hasNext()) {
- memoryManager.freePage(dataPagesIterator.next());
+ MemoryBlock dataPage = dataPagesIterator.next();
dataPagesIterator.remove();
+ taskMemoryManager.freePage(dataPage);
+ shuffleMemoryManager.release(dataPage.size());
}
assert(dataPages.isEmpty());
}
- /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
+ public TaskMemoryManager getTaskMemoryManager() {
+ return taskMemoryManager;
+ }
+
+ public ShuffleMemoryManager getShuffleMemoryManager() {
+ return shuffleMemoryManager;
+ }
+
+ public long getPageSizeBytes() {
+ return pageSizeBytes;
+ }
+
+ /**
+ * Returns the total amount of memory, in bytes, consumed by this map's managed structures.
+ */
public long getTotalMemoryConsumption() {
long totalDataPagesSize = 0L;
for (MemoryBlock dataPage : dataPages) {
totalDataPagesSize += dataPage.size();
}
- return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size();
+ return totalDataPagesSize +
+ ((bitset != null) ? bitset.memoryBlock().size() : 0L) +
+ ((longArray != null) ? longArray.memoryBlock().size() : 0L);
+ }
+
+ private void updatePeakMemoryUsed() {
+ long mem = getTotalMemoryConsumption();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
}
/**
@@ -565,7 +760,6 @@ public long getTimeSpentResizingNs() {
return timeSpentResizingNs;
}
-
/**
* Returns the average number of probes per key lookup.
*/
@@ -584,7 +778,7 @@ public long getNumHashCollisions() {
}
@VisibleForTesting
- int getNumDataPages() {
+ public int getNumDataPages() {
return dataPages.size();
}
@@ -593,6 +787,9 @@ int getNumDataPages() {
*/
@VisibleForTesting
void growAndRehash() {
+ assert(bitset != null);
+ assert(longArray != null);
+
long resizeStartTime = -1;
if (enablePerfMetrics) {
resizeStartTime = System.nanoTime();
@@ -628,16 +825,8 @@ void growAndRehash() {
}
}
- // Deallocate the old data structures.
- memoryManager.free(oldLongArray.memoryBlock());
if (enablePerfMetrics) {
timeSpentResizingNs += System.nanoTime() - resizeStartTime;
}
}
-
- /** Returns the next number greater or equal num that is power of 2. */
- private static long nextPowerOf2(long num) {
- final long highBit = Long.highestOneBit(num);
- return (highBit == num) ? num : highBit << 1;
- }
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
rename to core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 600aff7d15d8a..71b76d5ddfaa7 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -20,6 +20,7 @@
import com.google.common.primitives.UnsignedLongs;
import org.apache.spark.annotation.Private;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.UTF8String;
import org.apache.spark.util.Utils;
@@ -28,9 +29,13 @@ public class PrefixComparators {
private PrefixComparators() {}
public static final StringPrefixComparator STRING = new StringPrefixComparator();
- public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator();
- public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator();
+ public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
+ public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator();
+ public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc();
+ public static final LongPrefixComparator LONG = new LongPrefixComparator();
+ public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
+ public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc();
public static final class StringPrefixComparator extends PrefixComparator {
@Override
@@ -38,36 +43,62 @@ public int compare(long aPrefix, long bPrefix) {
return UnsignedLongs.compare(aPrefix, bPrefix);
}
- public long computePrefix(UTF8String value) {
+ public static long computePrefix(UTF8String value) {
return value == null ? 0L : value.getPrefix();
}
}
- /**
- * Prefix comparator for all integral types (boolean, byte, short, int, long).
- */
- public static final class IntegralPrefixComparator extends PrefixComparator {
+ public static final class StringPrefixComparatorDesc extends PrefixComparator {
@Override
- public int compare(long a, long b) {
- return (a < b) ? -1 : (a > b) ? 1 : 0;
+ public int compare(long bPrefix, long aPrefix) {
+ return UnsignedLongs.compare(aPrefix, bPrefix);
}
-
- public final long NULL_PREFIX = Long.MIN_VALUE;
}
- public static final class FloatPrefixComparator extends PrefixComparator {
+ public static final class BinaryPrefixComparator extends PrefixComparator {
@Override
public int compare(long aPrefix, long bPrefix) {
- float a = Float.intBitsToFloat((int) aPrefix);
- float b = Float.intBitsToFloat((int) bPrefix);
- return Utils.nanSafeCompareFloats(a, b);
+ return UnsignedLongs.compare(aPrefix, bPrefix);
}
- public long computePrefix(float value) {
- return Float.floatToIntBits(value) & 0xffffffffL;
+ public static long computePrefix(byte[] bytes) {
+ if (bytes == null) {
+ return 0L;
+ } else {
+ /**
+ * TODO: If a wrapper for BinaryType is created (SPARK-8786),
+ * these codes below will be in the wrapper class.
+ */
+ final int minLen = Math.min(bytes.length, 8);
+ long p = 0;
+ for (int i = 0; i < minLen; ++i) {
+ p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i))
+ << (56 - 8 * i);
+ }
+ return p;
+ }
}
+ }
- public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY);
+ public static final class BinaryPrefixComparatorDesc extends PrefixComparator {
+ @Override
+ public int compare(long bPrefix, long aPrefix) {
+ return UnsignedLongs.compare(aPrefix, bPrefix);
+ }
+ }
+
+ public static final class LongPrefixComparator extends PrefixComparator {
+ @Override
+ public int compare(long a, long b) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
+ }
+
+ public static final class LongPrefixComparatorDesc extends PrefixComparator {
+ @Override
+ public int compare(long b, long a) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
+ }
}
public static final class DoublePrefixComparator extends PrefixComparator {
@@ -78,10 +109,21 @@ public int compare(long aPrefix, long bPrefix) {
return Utils.nanSafeCompareDoubles(a, b);
}
- public long computePrefix(double value) {
+ public static long computePrefix(double value) {
return Double.doubleToLongBits(value);
}
+ }
+
+ public static final class DoublePrefixComparatorDesc extends PrefixComparator {
+ @Override
+ public int compare(long bPrefix, long aPrefix) {
+ double a = Double.longBitsToDouble(aPrefix);
+ double b = Double.longBitsToDouble(bPrefix);
+ return Utils.nanSafeCompareDoubles(a, b);
+ }
- public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY);
+ public static long computePrefix(double value) {
+ return Double.doubleToLongBits(value);
+ }
}
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 866e0b4151577..fc364e0a895b1 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -17,9 +17,12 @@
package org.apache.spark.util.collection.unsafe.sort;
+import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
+import javax.annotation.Nullable;
+
import scala.runtime.AbstractFunction0;
import scala.runtime.BoxedUnit;
@@ -27,12 +30,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.BlockManager;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
@@ -48,7 +51,7 @@ public final class UnsafeExternalSorter {
private final PrefixComparator prefixComparator;
private final RecordComparator recordComparator;
private final int initialSize;
- private final TaskMemoryManager memoryManager;
+ private final TaskMemoryManager taskMemoryManager;
private final ShuffleMemoryManager shuffleMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
@@ -63,26 +66,58 @@ public final class UnsafeExternalSorter {
* this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
* itself).
*/
- private final LinkedList allocatedPages = new LinkedList();
+ private final LinkedList allocatedPages = new LinkedList<>();
+
+ private final LinkedList spillWriters = new LinkedList<>();
// These variables are reset after spilling:
- private UnsafeInMemorySorter sorter;
+ @Nullable private UnsafeInMemorySorter inMemSorter;
+ // Whether the in-mem sorter is created internally, or passed in from outside.
+ // If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
+ private boolean isInMemSorterExternal = false;
private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;
+ private long peakMemoryUsedBytes = 0;
- private final LinkedList spillWriters = new LinkedList<>();
+ public static UnsafeExternalSorter createWithExistingInMemorySorter(
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ int initialSize,
+ long pageSizeBytes,
+ UnsafeInMemorySorter inMemorySorter) throws IOException {
+ return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+ taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
+ }
+
+ public static UnsafeExternalSorter create(
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ int initialSize,
+ long pageSizeBytes) throws IOException {
+ return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+ taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
+ }
- public UnsafeExternalSorter(
- TaskMemoryManager memoryManager,
+ private UnsafeExternalSorter(
+ TaskMemoryManager taskMemoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
- SparkConf conf) throws IOException {
- this.memoryManager = memoryManager;
+ long pageSizeBytes,
+ @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
+ this.taskMemoryManager = taskMemoryManager;
this.shuffleMemoryManager = shuffleMemoryManager;
this.blockManager = blockManager;
this.taskContext = taskContext;
@@ -90,9 +125,21 @@ public UnsafeExternalSorter(
this.prefixComparator = prefixComparator;
this.initialSize = initialSize;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
- this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
- this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
- initializeForWriting();
+ // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.fileBufferSizeBytes = 32 * 1024;
+ this.pageSizeBytes = pageSizeBytes;
+ this.writeMetrics = new ShuffleWriteMetrics();
+
+ if (existingInMemorySorter == null) {
+ initializeForWriting();
+ // Acquire a new page as soon as we construct the sorter to ensure that we have at
+ // least one page to work with. Otherwise, other operators in the same task may starve
+ // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter.
+ acquireNewPage();
+ } else {
+ this.isInMemSorterExternal = true;
+ this.inMemSorter = existingInMemorySorter;
+ }
// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
// the end of the task. This is necessary to avoid memory leaks in when the downstream operator
@@ -100,7 +147,7 @@ public UnsafeExternalSorter(
taskContext.addOnCompleteCallback(new AbstractFunction0() {
@Override
public BoxedUnit apply() {
- freeMemory();
+ cleanupResources();
return null;
}
});
@@ -114,56 +161,90 @@ public BoxedUnit apply() {
*/
private void initializeForWriting() throws IOException {
this.writeMetrics = new ShuffleWriteMetrics();
- // TODO: move this sizing calculation logic into a static method of sorter:
- final long memoryRequested = initialSize * 8L * 2;
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
- if (memoryAcquired != memoryRequested) {
+ final long pointerArrayMemory =
+ UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize);
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pointerArrayMemory);
+ if (memoryAcquired != pointerArrayMemory) {
shuffleMemoryManager.release(memoryAcquired);
- throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+ throw new IOException("Could not acquire " + pointerArrayMemory + " bytes of memory");
}
- this.sorter =
- new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize);
+ this.inMemSorter =
+ new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
+ this.isInMemSorterExternal = false;
}
/**
- * Sort and spill the current records in response to memory pressure.
+ * Marks the current page as no-more-space-available, and as a result, either allocate a
+ * new page or spill when we see the next record.
*/
@VisibleForTesting
+ public void closeCurrentPage() {
+ freeSpaceInCurrentPage = 0;
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
public void spill() throws IOException {
+ assert(inMemSorter != null);
logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
Utils.bytesToString(getMemoryUsage()),
spillWriters.size(),
spillWriters.size() > 1 ? " times" : " time");
- final UnsafeSorterSpillWriter spillWriter =
- new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
- sorter.numRecords());
- spillWriters.add(spillWriter);
- final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator();
- while (sortedRecords.hasNext()) {
- sortedRecords.loadNext();
- final Object baseObject = sortedRecords.getBaseObject();
- final long baseOffset = sortedRecords.getBaseOffset();
- final int recordLength = sortedRecords.getRecordLength();
- spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+ // We only write out contents of the inMemSorter if it is not empty.
+ if (inMemSorter.numRecords() > 0) {
+ final UnsafeSorterSpillWriter spillWriter =
+ new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+ inMemSorter.numRecords());
+ spillWriters.add(spillWriter);
+ final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator();
+ while (sortedRecords.hasNext()) {
+ sortedRecords.loadNext();
+ final Object baseObject = sortedRecords.getBaseObject();
+ final long baseOffset = sortedRecords.getBaseOffset();
+ final int recordLength = sortedRecords.getRecordLength();
+ spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+ }
+ spillWriter.close();
}
- spillWriter.close();
- final long sorterMemoryUsage = sorter.getMemoryUsage();
- sorter = null;
- shuffleMemoryManager.release(sorterMemoryUsage);
+
final long spillSize = freeMemory();
+ // Note that this is more-or-less going to be a multiple of the page size, so wasted space in
+ // pages will currently be counted as memory spilled even though that space isn't actually
+ // written to disk. This also counts the space needed to store the sorter's pointer array.
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+
initializeForWriting();
}
+ /**
+ * Return the total memory usage of this sorter, including the data pages and the sorter's pointer
+ * array.
+ */
private long getMemoryUsage() {
long totalPageSize = 0;
for (MemoryBlock page : allocatedPages) {
totalPageSize += page.size();
}
- return sorter.getMemoryUsage() + totalPageSize;
+ return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
+ }
+
+ private void updatePeakMemoryUsed() {
+ long mem = getMemoryUsage();
+ if (mem > peakMemoryUsedBytes) {
+ peakMemoryUsedBytes = mem;
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ updatePeakMemoryUsed();
+ return peakMemoryUsedBytes;
}
@VisibleForTesting
@@ -171,13 +252,27 @@ public int getNumberOfAllocatedPages() {
return allocatedPages.size();
}
- public long freeMemory() {
+ /**
+ * Free this sorter's in-memory data structures, including its data pages and pointer array.
+ *
+ * @return the number of bytes freed.
+ */
+ private long freeMemory() {
+ updatePeakMemoryUsed();
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
- memoryManager.freePage(block);
+ taskMemoryManager.freePage(block);
shuffleMemoryManager.release(block.size());
memoryFreed += block.size();
}
+ if (inMemSorter != null) {
+ if (!isInMemSorterExternal) {
+ long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+ memoryFreed += sorterMemoryUsage;
+ shuffleMemoryManager.release(sorterMemoryUsage);
+ }
+ inMemSorter = null;
+ }
allocatedPages.clear();
currentPage = null;
currentPagePosition = -1;
@@ -186,44 +281,61 @@ public long freeMemory() {
}
/**
- * Checks whether there is enough space to insert a new record into the sorter.
- *
- * @param requiredSpace the required space in the data page, in bytes, including space for storing
- * the record size.
+ * Deletes any spill files created by this sorter.
+ */
+ private void deleteSpillFiles() {
+ for (UnsafeSorterSpillWriter spill : spillWriters) {
+ File file = spill.getFile();
+ if (file != null && file.exists()) {
+ if (!file.delete()) {
+ logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+ }
+ }
+ }
+ }
- * @return true if the record can be inserted without requiring more allocations, false otherwise.
+ /**
+ * Frees this sorter's in-memory data structures and cleans up its spill files.
*/
- private boolean haveSpaceForRecord(int requiredSpace) {
- assert (requiredSpace > 0);
- return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+ public void cleanupResources() {
+ deleteSpillFiles();
+ freeMemory();
}
/**
- * Allocates more memory in order to insert an additional record. This will request additional
- * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
- * obtained.
- *
- * @param requiredSpace the required space in the data page, in bytes, including space for storing
- * the record size.
+ * Checks whether there is enough space to insert an additional record in to the sort pointer
+ * array and grows the array if additional space is required. If the required space cannot be
+ * obtained, then the in-memory data will be spilled to disk.
*/
- private void allocateSpaceForRecord(int requiredSpace) throws IOException {
- // TODO: merge these steps to first calculate total memory requirements for this insert,
- // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
- // data page.
- if (!sorter.hasSpaceForAnotherRecord()) {
+ private void growPointerArrayIfNecessary() throws IOException {
+ assert(inMemSorter != null);
+ if (!inMemSorter.hasSpaceForAnotherRecord()) {
logger.debug("Attempting to expand sort pointer array");
- final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+ final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
if (memoryAcquired < memoryToGrowPointerArray) {
shuffleMemoryManager.release(memoryAcquired);
spill();
} else {
- sorter.expandPointerArray();
+ inMemSorter.expandPointerArray();
shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
}
}
+ }
+ /**
+ * Allocates more memory in order to insert an additional record. This will request additional
+ * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+ * obtained.
+ *
+ * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * the record size. This must be less than or equal to the page size (records
+ * that exceed the page size are handled via a different code path which uses
+ * special overflow pages).
+ */
+ private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
+ assert (requiredSpace <= pageSizeBytes);
if (requiredSpace > freeSpaceInCurrentPage) {
logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
freeSpaceInCurrentPage);
@@ -234,24 +346,34 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
pageSizeBytes + ")");
} else {
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquired < pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquired);
- spill();
- final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquiredAfterSpilling != pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
- throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
- }
- }
- currentPage = memoryManager.allocatePage(pageSizeBytes);
- currentPagePosition = currentPage.getBaseOffset();
- freeSpaceInCurrentPage = pageSizeBytes;
- allocatedPages.add(currentPage);
+ acquireNewPage();
}
}
}
+ /**
+ * Acquire a new page from the {@link ShuffleMemoryManager}.
+ *
+ * If there is not enough space to allocate the new page, spill all existing ones
+ * and try again. If there is still not enough space, report error to the caller.
+ */
+ private void acquireNewPage() throws IOException {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquired < pageSizeBytes) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquiredAfterSpilling != pageSizeBytes) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
+ }
+ }
+ currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = pageSizeBytes;
+ allocatedPages.add(currentPage);
+ }
+
/**
* Write a record to the sorter.
*/
@@ -260,30 +382,134 @@ public void insertRecord(
long recordBaseOffset,
int lengthInBytes,
long prefix) throws IOException {
+
+ growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
final int totalSpaceRequired = lengthInBytes + 4;
- if (!haveSpaceForRecord(totalSpaceRequired)) {
- allocateSpaceForRecord(totalSpaceRequired);
+
+ // --- Figure out where to insert the new record ----------------------------------------------
+
+ final MemoryBlock dataPage;
+ long dataPagePosition;
+ boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
+ if (useOverflowPage) {
+ long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
+ // The record is larger than the page size, so allocate a special overflow page just to hold
+ // that record.
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+ if (memoryGranted != overflowPageSize) {
+ shuffleMemoryManager.release(memoryGranted);
+ spill();
+ final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+ if (memoryGrantedAfterSpill != overflowPageSize) {
+ shuffleMemoryManager.release(memoryGrantedAfterSpill);
+ throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
+ }
+ }
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ allocatedPages.add(overflowPage);
+ dataPage = overflowPage;
+ dataPagePosition = overflowPage.getBaseOffset();
+ } else {
+ // The record is small enough to fit in a regular data page, but the current page might not
+ // have enough space to hold it (or no pages have been allocated yet).
+ acquireNewPageIfNecessary(totalSpaceRequired);
+ dataPage = currentPage;
+ dataPagePosition = currentPagePosition;
+ // Update bookkeeping information
+ freeSpaceInCurrentPage -= totalSpaceRequired;
+ currentPagePosition += totalSpaceRequired;
}
+ final Object dataPageBaseObject = dataPage.getBaseObject();
+
+ // --- Insert the record ----------------------------------------------------------------------
final long recordAddress =
- memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
- final Object dataPageBaseObject = currentPage.getBaseObject();
- PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
- currentPagePosition += 4;
- PlatformDependent.copyMemory(
- recordBaseObject,
- recordBaseOffset,
- dataPageBaseObject,
- currentPagePosition,
- lengthInBytes);
- currentPagePosition += lengthInBytes;
- freeSpaceInCurrentPage -= totalSpaceRequired;
- sorter.insertRecord(recordAddress, prefix);
+ taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
+ Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
+ dataPagePosition += 4;
+ Platform.copyMemory(
+ recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
+ assert(inMemSorter != null);
+ inMemSorter.insertRecord(recordAddress, prefix);
}
+ /**
+ * Write a key-value record to the sorter. The key and value will be put together in-memory,
+ * using the following format:
+ *
+ * record length (4 bytes), key length (4 bytes), key data, value data
+ *
+ * record length = key length + value length + 4
+ */
+ public void insertKVRecord(
+ Object keyBaseObj, long keyOffset, int keyLen,
+ Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException {
+
+ growPointerArrayIfNecessary();
+ final int totalSpaceRequired = keyLen + valueLen + 4 + 4;
+
+ // --- Figure out where to insert the new record ----------------------------------------------
+
+ final MemoryBlock dataPage;
+ long dataPagePosition;
+ boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
+ if (useOverflowPage) {
+ long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
+ // The record is larger than the page size, so allocate a special overflow page just to hold
+ // that record.
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+ if (memoryGranted != overflowPageSize) {
+ shuffleMemoryManager.release(memoryGranted);
+ spill();
+ final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+ if (memoryGrantedAfterSpill != overflowPageSize) {
+ shuffleMemoryManager.release(memoryGrantedAfterSpill);
+ throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
+ }
+ }
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ allocatedPages.add(overflowPage);
+ dataPage = overflowPage;
+ dataPagePosition = overflowPage.getBaseOffset();
+ } else {
+ // The record is small enough to fit in a regular data page, but the current page might not
+ // have enough space to hold it (or no pages have been allocated yet).
+ acquireNewPageIfNecessary(totalSpaceRequired);
+ dataPage = currentPage;
+ dataPagePosition = currentPagePosition;
+ // Update bookkeeping information
+ freeSpaceInCurrentPage -= totalSpaceRequired;
+ currentPagePosition += totalSpaceRequired;
+ }
+ final Object dataPageBaseObject = dataPage.getBaseObject();
+
+ // --- Insert the record ----------------------------------------------------------------------
+
+ final long recordAddress =
+ taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
+ Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4);
+ dataPagePosition += 4;
+
+ Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen);
+ dataPagePosition += 4;
+
+ Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen);
+ dataPagePosition += keyLen;
+
+ Platform.copyMemory(valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen);
+
+ assert(inMemSorter != null);
+ inMemSorter.insertRecord(recordAddress, prefix);
+ }
+
+ /**
+ * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
+ * after consuming this iterator.
+ */
public UnsafeSorterIterator getSortedIterator() throws IOException {
- final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator();
+ assert(inMemSorter != null);
+ final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
if (spillWriters.isEmpty()) {
return inMemoryIterator;
@@ -291,12 +517,11 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
final UnsafeSorterSpillMerger spillMerger =
new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
- spillMerger.addSpill(spillWriter.getReader(blockManager));
+ spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
}
spillWriters.clear();
- if (inMemoryIterator.hasNext()) {
- spillMerger.addSpill(inMemoryIterator);
- }
+ spillMerger.addSpillIfNotEmpty(inMemoryIterator);
+
return spillMerger.getSortedIterator();
}
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index fc34ad9cff369..f7787e1019c2b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -19,7 +19,7 @@
import java.util.Comparator;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.collection.Sorter;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
@@ -100,6 +100,10 @@ public long getMemoryUsage() {
return pointerArray.length * 8L;
}
+ static long getMemoryRequirementsForPointerArray(long numEntries) {
+ return numEntries * 2L * 8L;
+ }
+
public boolean hasSpaceForAnotherRecord() {
return pointerArrayInsertPosition + 2 < pointerArray.length;
}
@@ -129,7 +133,7 @@ public void insertRecord(long recordPointer, long keyPrefix) {
pointerArrayInsertPosition++;
}
- private static final class SortedIterator extends UnsafeSorterIterator {
+ public static final class SortedIterator extends UnsafeSorterIterator {
private final TaskMemoryManager memoryManager;
private final int sortBufferInsertPosition;
@@ -140,7 +144,7 @@ private static final class SortedIterator extends UnsafeSorterIterator {
private long keyPrefix;
private int recordLength;
- SortedIterator(
+ private SortedIterator(
TaskMemoryManager memoryManager,
int sortBufferInsertPosition,
long[] sortBuffer) {
@@ -160,7 +164,7 @@ public void loadNext() {
final long recordPointer = sortBuffer[position];
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
- recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4);
+ recordLength = Platform.getInt(baseObject, baseOffset - 4);
keyPrefix = sortBuffer[position + 1];
position += 2;
}
@@ -182,7 +186,7 @@ public void loadNext() {
* Return an iterator over record pointers in sorted order. For efficiency, all calls to
* {@code next()} will return the same mutable object.
*/
- public UnsafeSorterIterator getSortedIterator() {
+ public SortedIterator getSortedIterator() {
sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index 8272c2a5be0d1..3874a9f9cbdb6 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -47,11 +47,19 @@ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) {
priorityQueue = new PriorityQueue(numSpills, comparator);
}
- public void addSpill(UnsafeSorterIterator spillReader) throws IOException {
+ /**
+ * Add an UnsafeSorterIterator to this merger
+ */
+ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOException {
if (spillReader.hasNext()) {
+ // We only add the spillReader to the priorityQueue if it is not empty. We do this to
+ // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator
+      // does not return a wrong result, because hasNext will return true
+ // at least priorityQueue.size() times. If we allow n spillReaders in the
+ // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator.
spillReader.loadNext();
+ priorityQueue.add(spillReader);
}
- priorityQueue.add(spillReader);
}
public UnsafeSorterIterator getSortedIterator() throws IOException {
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 29e9e0f30f934..4989b05d63e23 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -23,7 +23,7 @@
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.Platform;
/**
* Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
@@ -31,6 +31,7 @@
*/
final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+ private final File file;
private InputStream in;
private DataInputStream din;
@@ -41,13 +42,14 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
private byte[] arr = new byte[1024 * 1024];
private Object baseObject = arr;
- private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
+ private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
public UnsafeSorterSpillReader(
BlockManager blockManager,
File file,
BlockId blockId) throws IOException {
assert (file.length() > 0);
+ this.file = file;
final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
this.in = blockManager.wrapForCompression(blockId, bs);
this.din = new DataInputStream(this.in);
@@ -71,6 +73,7 @@ public void loadNext() throws IOException {
numRecordsRemaining--;
if (numRecordsRemaining == 0) {
in.close();
+ file.delete();
in = null;
din = null;
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 71eed29563d4a..e59a84ff8d118 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -28,7 +28,7 @@
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempLocalBlockId;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.Platform;
/**
* Spills a list of sorted records to disk. Spill files have the following format:
@@ -117,11 +117,11 @@ public void write(
long recordReadPosition = baseOffset;
while (dataRemaining > 0) {
final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining);
- PlatformDependent.copyMemory(
+ Platform.copyMemory(
baseObject,
recordReadPosition,
writeBuffer,
- PlatformDependent.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
+ Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
toTransfer);
writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer);
recordReadPosition += toTransfer;
@@ -140,6 +140,10 @@ public void close() throws IOException {
writeBuffer = null;
}
+ public File getFile() {
+ return file;
+ }
+
public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
return new UnsafeSorterSpillReader(blockManager, file, blockId);
}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
index 4a893bc0189aa..83dbea40b63f3 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
@@ -109,13 +109,13 @@ function toggleDagViz(forJob) {
}
$(function (){
- if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") {
+ if ($("#stage-dag-viz").length &&
+ window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") {
// Set it to false so that the click function can revert it
window.localStorage.setItem(expandDagVizArrowKey(false), "false");
toggleDagViz(false);
- }
-
- if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") {
+ } else if ($("#job-dag-viz").length &&
+ window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") {
// Set it to false so that the click function can revert it
window.localStorage.setItem(expandDagVizArrowKey(true), "false");
toggleDagViz(true);
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index b1cef47042247..04f3070d25b4a 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -207,7 +207,7 @@ span.additional-metric-title {
/* Hide all additional metrics by default. This is done here rather than using JavaScript to
* avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
.scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote,
-.serialization_time, .getting_result_time {
+.serialization_time, .getting_result_time, .peak_execution_memory {
display: none;
}
@@ -224,3 +224,11 @@ span.additional-metric-title {
a.expandbutton {
cursor: pointer;
}
+
+.executor-thread {
+ background: #E6E6E6;
+}
+
+.non-executor-thread {
+ background: #FAFAFA;
+}
\ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index eb75f26718e19..c39c8667d013e 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -152,8 +152,15 @@ class Accumulable[R, T] private[spark] (
in.defaultReadObject()
value_ = zero
deserialized = true
+ // Automatically register the accumulator when it is deserialized with the task closure.
+ //
+ // Note internal accumulators sent with task are deserialized before the TaskContext is created
+ // and are registered in the TaskContext constructor. Other internal accumulators, such SQL
+ // metrics, still need to register here.
val taskContext = TaskContext.get()
- taskContext.registerAccumulator(this)
+ if (taskContext != null) {
+ taskContext.registerAccumulator(this)
+ }
}
override def toString: String = if (value_ == null) "null" else value_.toString
@@ -248,10 +255,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
* @param param helper object defining how to add elements of type `T`
* @tparam T result type
*/
-class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String])
- extends Accumulable[T, T](initialValue, param, name) {
+class Accumulator[T] private[spark] (
+ @transient private[spark] val initialValue: T,
+ param: AccumulatorParam[T],
+ name: Option[String],
+ internal: Boolean)
+ extends Accumulable[T, T](initialValue, param, name, internal) {
+
+ def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = {
+ this(initialValue, param, name, false)
+ }
- def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None)
+ def this(initialValue: T, param: AccumulatorParam[T]) = {
+ this(initialValue, param, None, false)
+ }
}
/**
@@ -342,3 +359,41 @@ private[spark] object Accumulators extends Logging {
}
}
+
+private[spark] object InternalAccumulator {
+ val PEAK_EXECUTION_MEMORY = "peakExecutionMemory"
+ val TEST_ACCUMULATOR = "testAccumulator"
+
+ // For testing only.
+ // This needs to be a def since we don't want to reuse the same accumulator across stages.
+ private def maybeTestAccumulator: Option[Accumulator[Long]] = {
+ if (sys.props.contains("spark.testing")) {
+ Some(new Accumulator(
+ 0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true))
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Accumulators for tracking internal metrics.
+ *
+ * These accumulators are created with the stage such that all tasks in the stage will
+ * add to the same set of accumulators. We do this to report the distribution of accumulator
+ * values across all tasks within each stage.
+ */
+ def create(sc: SparkContext): Seq[Accumulator[Long]] = {
+ val internalAccumulators = Seq(
+ // Execution memory refers to the memory used by internal data structures created
+ // during shuffles, aggregations and joins. The value of this accumulator should be
+ // approximately the sum of the peak sizes across all such data structures created
+ // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort.
+ new Accumulator(
+ 0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true)
+ ) ++ maybeTestAccumulator.toSeq
+ internalAccumulators.foreach { accumulator =>
+ sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator))
+ }
+ internalAccumulators
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index ceeb58075d345..289aab9bd9e51 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -58,12 +58,7 @@ case class Aggregator[K, V, C] (
} else {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
- // Update task metrics if context is not null
- // TODO: Make context non optional in a future release
- Option(context).foreach { c =>
- c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
- c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
- }
+ updateMetrics(context, combiners)
combiners.iterator
}
}
@@ -89,13 +84,18 @@ case class Aggregator[K, V, C] (
} else {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
combiners.insertAll(iter)
- // Update task metrics if context is not null
- // TODO: Make context non-optional in a future release
- Option(context).foreach { c =>
- c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
- c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
- }
+ updateMetrics(context, combiners)
combiners.iterator
}
}
+
+ /** Update task metrics after populating the external map. */
+ private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = {
+ Option(context).foreach { c =>
+ c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
+ c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
+ c.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 37198d887b07b..a14a55ec352d3 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -17,35 +17,13 @@
package org.apache.spark
-import java.lang.ref.{ReferenceQueue, WeakReference}
-
-import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
-
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.{RDDCheckpointData, RDD}
+import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData}
import org.apache.spark.util.Utils
+import org.apache.spark.util.cleanup.{ CleanAccum, CleanBroadcast, CleanCheckpoint }
+import org.apache.spark.util.cleanup.{ CleanRDD, CleanShuffle, CleanupTask }
-/**
- * Classes that represent cleaning tasks.
- */
-private sealed trait CleanupTask
-private case class CleanRDD(rddId: Int) extends CleanupTask
-private case class CleanShuffle(shuffleId: Int) extends CleanupTask
-private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
-private case class CleanAccum(accId: Long) extends CleanupTask
-private case class CleanCheckpoint(rddId: Int) extends CleanupTask
-
-/**
- * A WeakReference associated with a CleanupTask.
- *
- * When the referent object becomes only weakly reachable, the corresponding
- * CleanupTaskWeakReference is automatically added to the given reference queue.
- */
-private class CleanupTaskWeakReference(
- val task: CleanupTask,
- referent: AnyRef,
- referenceQueue: ReferenceQueue[AnyRef])
- extends WeakReference(referent, referenceQueue)
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
/**
* An asynchronous cleaner for RDD, shuffle, and broadcast state.
@@ -54,18 +32,11 @@ private class CleanupTaskWeakReference(
* to be processed when the associated object goes out of scope of the application. Actual
* cleanup is performed in a separate daemon thread.
*/
-private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
-
- private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
- with SynchronizedBuffer[CleanupTaskWeakReference]
-
- private val referenceQueue = new ReferenceQueue[AnyRef]
+private[spark] class ContextCleaner(sc: SparkContext) extends WeakReferenceCleaner {
private val listeners = new ArrayBuffer[CleanerListener]
with SynchronizedBuffer[CleanerListener]
- private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
-
/**
* Whether the cleaning thread will block on cleanup tasks (other than shuffle, which
* is controlled by the `spark.cleaner.referenceTracking.blocking.shuffle` parameter).
@@ -92,35 +63,11 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
private val blockOnShuffleCleanupTasks = sc.conf.getBoolean(
"spark.cleaner.referenceTracking.blocking.shuffle", false)
- @volatile private var stopped = false
-
/** Attach a listener object to get information of when objects are cleaned. */
def attachListener(listener: CleanerListener): Unit = {
listeners += listener
}
- /** Start the cleaner. */
- def start(): Unit = {
- cleaningThread.setDaemon(true)
- cleaningThread.setName("Spark Context Cleaner")
- cleaningThread.start()
- }
-
- /**
- * Stop the cleaning thread and wait until the thread has finished running its current task.
- */
- def stop(): Unit = {
- stopped = true
- // Interrupt the cleaning thread, but wait until the current task has finished before
- // doing so. This guards against the race condition where a cleaning thread may
- // potentially clean similarly named variables created by a different SparkContext,
- // resulting in otherwise inexplicable block-not-found exceptions (SPARK-6132).
- synchronized {
- cleaningThread.interrupt()
- }
- cleaningThread.join()
- }
-
/** Register a RDD for cleanup when it is garbage collected. */
def registerRDDForCleanup(rdd: RDD[_]): Unit = {
registerForCleanup(rdd, CleanRDD(rdd.id))
@@ -145,43 +92,30 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
registerForCleanup(rdd, CleanCheckpoint(parentId))
}
- /** Register an object for cleanup. */
- private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
- referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
+ /** Keep cleaning RDD, shuffle, and broadcast state. */
+ override protected def keepCleaning(): Unit = Utils.tryOrStopSparkContext(sc) {
+ super.keepCleaning()
}
- /** Keep cleaning RDD, shuffle, and broadcast state. */
- private def keepCleaning(): Unit = Utils.tryOrStopSparkContext(sc) {
- while (!stopped) {
- try {
- val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
- .map(_.asInstanceOf[CleanupTaskWeakReference])
- // Synchronize here to avoid being interrupted on stop()
- synchronized {
- reference.map(_.task).foreach { task =>
- logDebug("Got cleaning task " + task)
- referenceBuffer -= reference.get
- task match {
- case CleanRDD(rddId) =>
- doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
- case CleanShuffle(shuffleId) =>
- doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks)
- case CleanBroadcast(broadcastId) =>
- doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
- case CleanAccum(accId) =>
- doCleanupAccum(accId, blocking = blockOnCleanupTasks)
- case CleanCheckpoint(rddId) =>
- doCleanCheckpoint(rddId)
- }
- }
- }
- } catch {
- case ie: InterruptedException if stopped => // ignore
- case e: Exception => logError("Error in cleaning thread", e)
- }
+ protected def handleCleanupForSpecificTask(task: CleanupTask): Unit = {
+ task match {
+ case CleanRDD(rddId) =>
+ doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
+ case CleanShuffle(shuffleId) =>
+ doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks)
+ case CleanBroadcast(broadcastId) =>
+ doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
+ case CleanAccum(accId) =>
+ doCleanupAccum(accId, blocking = blockOnCleanupTasks)
+ case CleanCheckpoint(rddId) =>
+ doCleanCheckpoint(rddId)
+ case unknown =>
+ logWarning(s"Got a cleanup task $unknown that cannot be handled by ContextCleaner,")
}
}
+ protected def cleanupThreadName(): String = "Context Cleaner"
+
/** Perform RDD cleanup. */
def doCleanupRDD(rddId: Int, blocking: Boolean): Unit = {
try {
@@ -231,11 +165,14 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}
- /** Perform checkpoint cleanup. */
+ /**
+ * Clean up checkpoint files written to reliable storage.
+ * Locally checkpointed files are cleaned up separately through RDD cleanups.
+ */
def doCleanCheckpoint(rddId: Int): Unit = {
try {
logDebug("Cleaning rdd checkpoint data " + rddId)
- RDDCheckpointData.clearRDDCheckpointData(sc, rddId)
+ ReliableRDDCheckpointData.cleanCheckpoint(sc, rddId)
listeners.foreach(_.checkpointCleaned(rddId))
logInfo("Cleaned rdd checkpoint data " + rddId)
}
@@ -249,10 +186,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
}
-private object ContextCleaner {
- private val REF_QUEUE_POLL_TIMEOUT = 100
-}
-
/**
* Listener class used for testing when any item has been cleaned by the Cleaner class.
*/
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 1877aaf2cac55..b93536e6536e2 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -599,14 +599,8 @@ private[spark] class ExecutorAllocationManager(
// If this is the last pending task, mark the scheduler queue as empty
stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
- val numTasksScheduled = stageIdToTaskIndices(stageId).size
- val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1)
- if (numTasksScheduled == numTasksTotal) {
- // No more pending tasks for this stage
- stageIdToNumTasks -= stageId
- if (stageIdToNumTasks.isEmpty) {
- allocationManager.onSchedulerQueueEmpty()
- }
+ if (totalPendingTasks() == 0) {
+ allocationManager.onSchedulerQueueEmpty()
}
// Mark the executor on which this task is scheduled as busy
@@ -618,6 +612,8 @@ private[spark] class ExecutorAllocationManager(
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val executorId = taskEnd.taskInfo.executorId
val taskId = taskEnd.taskInfo.taskId
+ val taskIndex = taskEnd.taskInfo.index
+ val stageId = taskEnd.stageId
allocationManager.synchronized {
numRunningTasks -= 1
// If the executor is no longer running any scheduled tasks, mark it as idle
@@ -628,6 +624,16 @@ private[spark] class ExecutorAllocationManager(
allocationManager.onExecutorIdle(executorId)
}
}
+
+ // If the task failed, we expect it to be resubmitted later. To ensure we have
+ // enough resources to run the resubmitted task, we need to mark the scheduler
+ // as backlogged again if it's not already marked as such (SPARK-8366)
+ if (taskEnd.reason != Success) {
+ if (totalPendingTasks() == 0) {
+ allocationManager.onSchedulerBacklogged()
+ }
+ stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) }
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala b/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala
new file mode 100644
index 0000000000000..716f0906e9fc3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark
+
+import java.io.File
+
+import org.apache.spark.util.cleanup.{CleanupTask, CleanExternalList}
+import org.apache.spark.util.collection.ExternalList
+
+/**
+ * Asynchronous cleaner for objects created on the Executor. So far
+ * only supports cleaning up ExternalList objects. Equivalent to ContextCleaner
+ * but for objects on the Executor heap.
+ */
+private[spark] class ExecutorCleaner extends WeakReferenceCleaner {
+
+ def registerExternalListForCleanup(list: ExternalList[_]): Unit = {
+ registerForCleanup(list, CleanExternalList(list.getBackingFileLocations()))
+ }
+
+ def doCleanExternalList(paths: Iterable[String]): Unit = {
+ paths.map(path => new File(path)).foreach(f => {
+ if (f.exists()) {
+ val isDeleted = f.delete()
+ if (!isDeleted) {
+ logWarning(s"Failed to delete ${f.getAbsolutePath} backing ExternalList")
+ }
+ }
+ })
+ }
+
+ override protected def handleCleanupForSpecificTask(task: CleanupTask): Unit = {
+ task match {
+ case CleanExternalList(paths) => doCleanExternalList(paths)
+ case unknown => logWarning(s"Got cleanup task that cannot be" +
+ s" handled by ExecutorCleaner: $unknown")
+ }
+ }
+
+ override protected def cleanupThreadName(): String = "Executor Cleaner"
+}
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 43dd4a170731d..ee60d697d8799 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -177,16 +177,14 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
s"${now - lastSeenMs} ms exceeds timeout $executorTimeoutMs ms")
scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " +
s"timed out after ${now - lastSeenMs} ms"))
- if (sc.supportDynamicAllocation) {
// Asynchronously kill the executor to avoid blocking the current thread
- killExecutorThread.submit(new Runnable {
- override def run(): Unit = Utils.tryLogNonFatalError {
- // Note: we want to get an executor back after expiring this one,
- // so do not simply call `sc.killExecutor` here (SPARK-8119)
- sc.killAndReplaceExecutor(executorId)
- }
- })
- }
+ killExecutorThread.submit(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ // Note: we want to get an executor back after expiring this one,
+ // so do not simply call `sc.killExecutor` here (SPARK-8119)
+ sc.killAndReplaceExecutor(executorId)
+ }
+ })
executorLastSeen.remove(executorId)
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 4161792976c7b..b344b5e173d67 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -249,6 +249,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
Utils.byteStringAsBytes(get(key, defaultValue))
}
+ /**
+ * Get a size parameter as bytes, falling back to a default if not set.
+ */
+ def getSizeAsBytes(key: String, defaultValue: Long): Long = {
+ Utils.byteStringAsBytes(get(key, defaultValue + "B"))
+ }
+
/**
* Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no
* suffix is provided then Kibibytes are assumed.
@@ -382,6 +389,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
val driverOptsKey = "spark.driver.extraJavaOptions"
val driverClassPathKey = "spark.driver.extraClassPath"
val driverLibraryPathKey = "spark.driver.extraLibraryPath"
+ val sparkExecutorInstances = "spark.executor.instances"
// Used by Yarn in 1.1 and before
sys.props.get("spark.driver.libraryPath").foreach { value =>
@@ -469,6 +477,24 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
}
}
}
+
+ if (!contains(sparkExecutorInstances)) {
+ sys.env.get("SPARK_WORKER_INSTANCES").foreach { value =>
+ val warning =
+ s"""
+ |SPARK_WORKER_INSTANCES was detected (set to '$value').
+ |This is deprecated in Spark 1.0+.
+ |
+ |Please instead use:
+ | - ./spark-submit with --num-executors to specify the number of executors
+ | - Or set SPARK_EXECUTOR_INSTANCES
+ | - spark.executor.instances to configure the number of instances in the spark config.
+ """.stripMargin
+ logWarning(warning)
+
+ set("spark.executor.instances", value)
+ }
+ }
}
/**
@@ -548,7 +574,9 @@ private[spark] object SparkConf extends Logging {
"spark.rpc.askTimeout" -> Seq(
AlternateConfig("spark.akka.askTimeout", "1.4")),
"spark.rpc.lookupTimeout" -> Seq(
- AlternateConfig("spark.akka.lookupTimeout", "1.4"))
+ AlternateConfig("spark.akka.lookupTimeout", "1.4")),
+ "spark.streaming.fileStream.minRememberDuration" -> Seq(
+ AlternateConfig("spark.streaming.minRememberDuration", "1.5"))
)
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index ac6ac6c216767..2e01a9a18c784 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -528,11 +528,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
// Optionally scale number of executors dynamically based on workload. Exposed for testing.
- val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false)
+ val dynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(_conf)
+ if (!dynamicAllocationEnabled && _conf.getBoolean("spark.dynamicAllocation.enabled", false)) {
+ logInfo("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.")
+ }
+
_executorAllocationManager =
if (dynamicAllocationEnabled) {
- assert(supportDynamicAllocation,
- "Dynamic allocation of executors is currently only supported in YARN and Mesos mode")
Some(new ExecutorAllocationManager(this, listenerBus, _conf))
} else {
None
@@ -561,7 +563,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// Make sure the context is stopped if the user forgets about it. This avoids leaving
// unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM
// is killed, though.
- _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () =>
+ _shutdownHookRef = ShutdownHookManager.addShutdownHook(
+ ShutdownHookManager.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () =>
logInfo("Invoking stop() from shutdown hook")
stop()
}
@@ -631,7 +634,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* [[org.apache.spark.SparkContext.setLocalProperty]].
*/
def getLocalProperty(key: String): String =
- Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
+ Option(localProperties.get).map(_.getProperty(key)).orNull
/** Set a human readable description of the current job. */
def setJobDescription(value: String) {
@@ -868,7 +871,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* }}}
*
* Do
- * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`,
+ * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
*
* then `rdd` contains
* {{{
@@ -1194,7 +1197,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
protected[spark] def checkpointFile[T: ClassTag](path: String): RDD[T] = withScope {
- new CheckpointRDD[T](this, path)
+ new ReliableCheckpointRDD[T](this, path)
}
/** Build the union of a list of RDDs. */
@@ -1361,17 +1364,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
postEnvironmentUpdate()
}
- /**
- * Return whether dynamically adjusting the amount of resources allocated to
- * this application is supported. This is currently only available for YARN
- * and Mesos coarse-grained mode.
- */
- private[spark] def supportDynamicAllocation: Boolean = {
- (master.contains("yarn")
- || master.contains("mesos")
- || _conf.getBoolean("spark.dynamicAllocation.testing", false))
- }
-
/**
* :: DeveloperApi ::
* Register a listener to receive up-calls from events that happen during execution.
@@ -1400,8 +1392,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
localityAwareTasks: Int,
hostToLocalTaskCount: scala.collection.immutable.Map[String, Int]
): Boolean = {
- assert(supportDynamicAllocation,
- "Requesting executors is currently only supported in YARN and Mesos modes")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.requestTotalExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount)
@@ -1414,12 +1404,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* :: DeveloperApi ::
* Request an additional number of executors from the cluster manager.
- * This is currently only supported in YARN mode. Return whether the request is received.
+ * @return whether the request is received.
*/
@DeveloperApi
override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
- assert(supportDynamicAllocation,
- "Requesting executors is currently only supported in YARN and Mesos modes")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.requestExecutors(numAdditionalExecutors)
@@ -1438,12 +1426,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* through this method with new ones, it should follow up explicitly with a call to
* {{SparkContext#requestExecutors}}.
*
- * This is currently only supported in YARN mode. Return whether the request is received.
+ * @return whether the request is received.
*/
@DeveloperApi
override def killExecutors(executorIds: Seq[String]): Boolean = {
- assert(supportDynamicAllocation,
- "Killing executors is currently only supported in YARN and Mesos modes")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.killExecutors(executorIds)
@@ -1462,7 +1448,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* through this method with a new one, it should follow up explicitly with a call to
* {{SparkContext#requestExecutors}}.
*
- * This is currently only supported in YARN mode. Return whether the request is received.
+ * @return whether the request is received.
*/
@DeveloperApi
override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId)
@@ -1479,7 +1465,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* can steal the window of opportunity and acquire this application's resources in the
* mean time.
*
- * This is currently only supported in YARN mode. Return whether the request is received.
+ * @return whether the request is received.
*/
private[spark] def killAndReplaceExecutor(executorId: String): Boolean = {
schedulerBackend match {
@@ -1686,36 +1672,60 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
return
}
if (_shutdownHookRef != null) {
- Utils.removeShutdownHook(_shutdownHookRef)
+ ShutdownHookManager.removeShutdownHook(_shutdownHookRef)
}
- postApplicationEnd()
- _ui.foreach(_.stop())
+ Utils.tryLogNonFatalError {
+ postApplicationEnd()
+ }
+ Utils.tryLogNonFatalError {
+ _ui.foreach(_.stop())
+ }
if (env != null) {
- env.metricsSystem.report()
+ Utils.tryLogNonFatalError {
+ env.metricsSystem.report()
+ }
}
if (metadataCleaner != null) {
- metadataCleaner.cancel()
+ Utils.tryLogNonFatalError {
+ metadataCleaner.cancel()
+ }
+ }
+ Utils.tryLogNonFatalError {
+ _cleaner.foreach(_.stop())
+ }
+ Utils.tryLogNonFatalError {
+ _executorAllocationManager.foreach(_.stop())
}
- _cleaner.foreach(_.stop())
- _executorAllocationManager.foreach(_.stop())
if (_dagScheduler != null) {
- _dagScheduler.stop()
+ Utils.tryLogNonFatalError {
+ _dagScheduler.stop()
+ }
_dagScheduler = null
}
if (_listenerBusStarted) {
- listenerBus.stop()
- _listenerBusStarted = false
+ Utils.tryLogNonFatalError {
+ listenerBus.stop()
+ _listenerBusStarted = false
+ }
+ }
+ Utils.tryLogNonFatalError {
+ _eventLogger.foreach(_.stop())
}
- _eventLogger.foreach(_.stop())
if (env != null && _heartbeatReceiver != null) {
- env.rpcEnv.stop(_heartbeatReceiver)
+ Utils.tryLogNonFatalError {
+ env.rpcEnv.stop(_heartbeatReceiver)
+ }
+ }
+ Utils.tryLogNonFatalError {
+ _progressBar.foreach(_.stop())
}
- _progressBar.foreach(_.stop())
_taskScheduler = null
// TODO: Cache.stop()?
if (_env != null) {
- _env.stop()
+ Utils.tryLogNonFatalError {
+ _env.stop()
+ }
SparkEnv.set(null)
}
SparkContext.clearActiveContext()
@@ -2653,7 +2663,7 @@ object SparkContext extends Logging {
val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false)
val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, sc, url)
+ new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager)
} else {
new MesosSchedulerBackend(scheduler, sc, url)
}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index adfece4d6e7c0..36bc0730688ac 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -73,6 +73,7 @@ class SparkEnv (
val shuffleMemoryManager: ShuffleMemoryManager,
val executorMemoryManager: ExecutorMemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
+ val executorCleaner: ExecutorCleaner,
val conf: SparkConf) extends Logging {
// TODO Remove actorSystem
@@ -101,6 +102,7 @@ class SparkEnv (
blockManager.master.stop()
metricsSystem.stop()
outputCommitCoordinator.stop()
+ executorCleaner.stop()
rpcEnv.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
@@ -324,7 +326,7 @@ object SparkEnv extends Logging {
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
- val shuffleMemoryManager = new ShuffleMemoryManager(conf)
+ val shuffleMemoryManager = ShuffleMemoryManager.create(conf, numUsableCores)
val blockTransferService =
conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
@@ -398,6 +400,8 @@ object SparkEnv extends Logging {
}
new ExecutorMemoryManager(allocator)
}
+ val executorCleaner = new ExecutorCleaner
+ executorCleaner.start()
val envInstance = new SparkEnv(
executorId,
@@ -417,6 +421,7 @@ object SparkEnv extends Logging {
shuffleMemoryManager,
executorMemoryManager,
outputCommitCoordinator,
+ executorCleaner,
conf)
// Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index b48836d5c8897..63cca80b2d734 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -59,6 +59,14 @@ object TaskContext {
* Unset the thread local TaskContext. Internal to Spark.
*/
protected[spark] def unset(): Unit = taskContext.remove()
+
+ /**
+ * An empty task context that does not represent an actual task.
+ */
+ private[spark] def empty(): TaskContextImpl = {
+ new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty)
+ }
+
}
@@ -179,4 +187,9 @@ abstract class TaskContext extends Serializable {
* accumulator id and the value of the Map is the latest accumulator local value.
*/
private[spark] def collectAccumulators(): Map[Long, Any]
+
+ /**
+ * Accumulators for tracking internal metrics indexed by the name.
+ */
+ private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]]
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 9ee168ae016f8..5df94c6d3a103 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -32,6 +32,7 @@ private[spark] class TaskContextImpl(
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
@transient private val metricsSystem: MetricsSystem,
+ internalAccumulators: Seq[Accumulator[Long]],
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
@@ -114,4 +115,11 @@ private[spark] class TaskContextImpl(
private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
accumulators.mapValues(_.localValue).toMap
}
+
+ private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = {
+ // Explicitly register internal accumulators here because these are
+ // not captured in the task closure and are already deserialized
+ internalAccumulators.foreach(registerAccumulator)
+ internalAccumulators.map { a => (a.name.get, a) }.toMap
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 48fd3e7e23d52..934d00dc708b9 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -17,6 +17,8 @@
package org.apache.spark
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
@@ -90,6 +92,10 @@ case class FetchFailed(
*
* `fullStackTrace` is a better representation of the stack trace because it contains the whole
* stack trace including the exception and its causes
+ *
+ * `exception` is the actual exception that caused the task to fail. It may be `None` in
+ * the case that the exception is not in fact serializable. If a task fails more than
+ * once (due to retries), `exception` is the one that caused the last failure.
*/
@DeveloperApi
case class ExceptionFailure(
@@ -97,11 +103,26 @@ case class ExceptionFailure(
description: String,
stackTrace: Array[StackTraceElement],
fullStackTrace: String,
- metrics: Option[TaskMetrics])
+ metrics: Option[TaskMetrics],
+ private val exceptionWrapper: Option[ThrowableSerializationWrapper])
extends TaskFailedReason {
+ /**
+ * `preserveCause` is used to keep the exception itself so it is available to the
+ * driver. This may be set to `false` in the event that the exception is not in fact
+ * serializable.
+ */
+ private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) {
+ this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics,
+ if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None)
+ }
+
private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
- this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics)
+ this(e, metrics, preserveCause = true)
+ }
+
+ def exception: Option[Throwable] = exceptionWrapper.flatMap {
+ (w: ThrowableSerializationWrapper) => Option(w.exception)
}
override def toErrorString: String =
@@ -127,6 +148,25 @@ case class ExceptionFailure(
}
}
+/**
+ * A class for recovering from exceptions when deserializing a Throwable that was
+ * thrown in user task code. If the Throwable cannot be deserialized it will be null,
+ * but the stacktrace and message will be preserved correctly in SparkException.
+ */
+private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends
+ Serializable with Logging {
+ private def writeObject(out: ObjectOutputStream): Unit = {
+ out.writeObject(exception)
+ }
+ private def readObject(in: ObjectInputStream): Unit = {
+ try {
+ exception = in.readObject().asInstanceOf[Throwable]
+ } catch {
+ case e : Exception => log.warn("Task exception could not be deserialized", e)
+ }
+ }
+}
+
/**
* :: DeveloperApi ::
* The task finished successfully, but the result was lost from the executor's block manager before
diff --git a/core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala b/core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala
new file mode 100644
index 0000000000000..0dd6d4773dcb6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark
+
+import java.lang.ref.ReferenceQueue
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+
+import org.apache.spark.util.cleanup.{CleanupTask, CleanupTaskWeakReference}
+
+/**
+ * Utility trait that keeps a long-running thread for cleaning up weak references
+ * after they are GCed. Currently implemented by ContextCleaner and ExecutorCleaner
+ * only.
+ */
+private[spark] trait WeakReferenceCleaner extends Logging {
+
+ private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
+ with SynchronizedBuffer[CleanupTaskWeakReference]
+
+ private val referenceQueue = new ReferenceQueue[AnyRef]
+
+ private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
+
+ private var stopped = false
+
+ /** Start the cleaner. */
+ def start(): Unit = {
+ cleaningThread.setDaemon(true)
+ cleaningThread.setName(cleanupThreadName())
+ cleaningThread.start()
+ }
+
+ def stop(): Unit = {
+ stopped = true
+ synchronized {
+ // Interrupt the cleaning thread, but wait until the current task has finished before
+ // doing so. This guards against the race condition where a cleaning thread may
+ // potentially clean similarly named variables created by a different SparkContext,
+ // resulting in otherwise inexplicable block-not-found exceptions (SPARK-6132).
+ cleaningThread.interrupt()
+ }
+ cleaningThread.join()
+ }
+
+ protected def keepCleaning(): Unit = {
+ while (!stopped) {
+ try {
+ val reference = Option(referenceQueue.remove(WeakReferenceCleaner.REF_QUEUE_POLL_TIMEOUT))
+ .map(_.asInstanceOf[CleanupTaskWeakReference])
+ // Synchronize here to avoid being interrupted on stop()
+ synchronized {
+ reference.map(_.task).foreach { task =>
+ logDebug("Got cleaning task " + task)
+ referenceBuffer -= reference.get
+ handleCleanupForSpecificTask(task)
+ }
+ }
+ } catch {
+ case ie: InterruptedException if stopped => // ignore
+ case e: Exception => logError("Error in cleaning thread", e)
+ }
+ }
+ }
+
+ /** Register an object for cleanup. */
+ protected def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
+ referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
+ }
+
+ protected def handleCleanupForSpecificTask(task: CleanupTask)
+ protected def cleanupThreadName(): String
+}
+
+private object WeakReferenceCleaner {
+ private val REF_QUEUE_POLL_TIMEOUT = 100
+}
diff --git a/core/src/main/scala/org/apache/spark/annotation/Since.scala b/core/src/main/scala/org/apache/spark/annotation/Since.scala
new file mode 100644
index 0000000000000..fa59393c22476
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/annotation/Since.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.annotation
+
+import scala.annotation.StaticAnnotation
+
+/**
+ * A Scala annotation that specifies the Spark version when a definition was added.
+ * Different from the `@since` tag in JavaDoc, this annotation does not require explicit JavaDoc and
+ * hence works for overridden methods that inherit API documentation directly from parents.
+ * The limitation is that it does not show up in the generated Java API documentation.
+ */
+private[spark] class Since(version: String) extends StaticAnnotation
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 55e563ee968be..2a56bf28d7027 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -794,7 +794,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
/**
* We try to reuse a single Socket to transfer accumulator updates, as they are all added
- * by the DAGScheduler's single-threaded actor anyway.
+ * by the DAGScheduler's single-threaded RpcEndpoint anyway.
*/
@transient var socket: Socket = _
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index a5de10fe89c42..14dac4ed28ce3 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -69,8 +69,11 @@ private[r] class RBackendHandler(server: RBackend)
case e: Exception =>
logError(s"Removing $objId failed", e)
writeInt(dos, -1)
+ writeString(dos, s"Removing $objId failed: ${e.getMessage}")
}
- case _ => dos.writeInt(-1)
+ case _ =>
+ dos.writeInt(-1)
+ writeString(dos, s"Error: unknown method $methodName")
}
} else {
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
@@ -146,8 +149,11 @@ private[r] class RBackendHandler(server: RBackend)
}
} catch {
case e: Exception =>
- logError(s"$methodName on $objId failed", e)
+ logError(s"$methodName on $objId failed")
writeInt(dos, -1)
+      // Write the error message of the exception's cause. This will be returned
+ // to user in the R process.
+ writeString(dos, Utils.exceptionString(e.getCause))
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
index d53abd3408c55..427b2bc7cbcbb 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
@@ -19,6 +19,8 @@ package org.apache.spark.api.r
import java.io.File
+import scala.collection.JavaConversions._
+
import org.apache.spark.{SparkEnv, SparkException}
private[spark] object RUtils {
@@ -26,7 +28,7 @@ private[spark] object RUtils {
* Get the SparkR package path in the local spark distribution.
*/
def localSparkRPackagePath: Option[String] = {
- val sparkHome = sys.env.get("SPARK_HOME")
+ val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.test.home"))
sparkHome.map(
Seq(_, "R", "lib").mkString(File.separator)
)
@@ -46,8 +48,8 @@ private[spark] object RUtils {
(sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode"))
}
- val isYarnCluster = master.contains("yarn") && deployMode == "cluster"
- val isYarnClient = master.contains("yarn") && deployMode == "client"
+ val isYarnCluster = master != null && master.contains("yarn") && deployMode == "cluster"
+ val isYarnClient = master != null && master.contains("yarn") && deployMode == "client"
// In YARN mode, the SparkR package is distributed as an archive symbolically
// linked to the "sparkr" file in the current directory. Note that this does not apply
@@ -62,4 +64,14 @@ private[spark] object RUtils {
}
}
}
+
+ /** Check if R is installed before running tests that use R commands. */
+ def isRInstalled: Boolean = {
+ try {
+ val builder = new ProcessBuilder(Seq("R", "--version"))
+ builder.start().waitFor() == 0
+ } catch {
+ case e: Exception => false
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 12727de9b4cf3..d8084a57658ad 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -107,6 +107,10 @@ private[deploy] object DeployMessages {
case class MasterChangeAcknowledged(appId: String)
+ case class RequestExecutors(appId: String, requestedTotal: Int)
+
+ case class KillExecutors(appId: String, executorIds: Seq[String])
+
// Master to AppClient
case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends DeployMessage
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
index 09973a0a2c998..20a9faa1784b7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -27,6 +27,7 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.sasl.SaslServerBootstrap
import org.apache.spark.network.server.TransportServer
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
+import org.apache.spark.network.util.TransportConf
import org.apache.spark.util.Utils
/**
@@ -45,11 +46,16 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
- private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
+ private val blockHandler = newShuffleBlockHandler(transportConf)
private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler)
private var server: TransportServer = _
+ /** Create a new shuffle block handler. Factored out for subclasses to override. */
+ protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = {
+ new ExternalShuffleBlockHandler(conf)
+ }
+
/** Starts the external shuffle service if the user has configured us to. */
def startIfEnabled() {
if (enabled) {
@@ -70,6 +76,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
server = transportContext.createServer(port, bootstraps)
}
+ /** Clean up all shuffle files associated with an application that has exited. */
+ def applicationRemoved(appId: String): Unit = {
+ blockHandler.applicationRemoved(appId, true /* cleanupLocalDirs */)
+ }
+
def stop() {
if (server != null) {
server.close()
@@ -88,6 +99,13 @@ object ExternalShuffleService extends Logging {
private val barrier = new CountDownLatch(1)
def main(args: Array[String]): Unit = {
+ main(args, (conf: SparkConf, sm: SecurityManager) => new ExternalShuffleService(conf, sm))
+ }
+
+ /** A helper main method that allows the caller to call this with a custom shuffle service. */
+ private[spark] def main(
+ args: Array[String],
+ newShuffleService: (SparkConf, SecurityManager) => ExternalShuffleService): Unit = {
val sparkConf = new SparkConf
Utils.loadDefaultSparkProperties(sparkConf)
val securityManager = new SecurityManager(sparkConf)
@@ -95,7 +113,7 @@ object ExternalShuffleService extends Logging {
// we override this value since this service is started from the command line
// and we assume the user really wants it to be running
sparkConf.set("spark.shuffle.service.enabled", "true")
- server = new ExternalShuffleService(sparkConf, securityManager)
+ server = newShuffleService(sparkConf, securityManager)
server.start()
installShutdownHook()
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index 53356addf6edb..83ccaadfe7447 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -73,12 +73,8 @@ class LocalSparkCluster(
def stop() {
logInfo("Shutting down local Spark cluster.")
// Stop the workers before the master so they don't get upset that it disconnected
- // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors!
- // This is unfortunate, but for now we just comment it out.
workerRpcEnvs.foreach(_.shutdown())
- // workerActorSystems.foreach(_.awaitTermination())
masterRpcEnvs.foreach(_.shutdown())
- // masterActorSystems.foreach(_.awaitTermination())
masterRpcEnvs.clear()
workerRpcEnvs.clear()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala
new file mode 100644
index 0000000000000..ed1e972955679
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io._
+import java.util.jar.JarFile
+import java.util.logging.Level
+import java.util.zip.{ZipEntry, ZipOutputStream}
+
+import scala.collection.JavaConversions._
+
+import com.google.common.io.{ByteStreams, Files}
+
+import org.apache.spark.{SparkException, Logging}
+import org.apache.spark.api.r.RUtils
+import org.apache.spark.util.{RedirectThread, Utils}
+
+private[deploy] object RPackageUtils extends Logging {
+
+ /** The key in the MANIFEST.mf that we look for, in case a jar contains R code. */
+ private final val hasRPackage = "Spark-HasRPackage"
+
+ /** Base of the shell command used in order to install R packages. */
+ private final val baseInstallCmd = Seq("R", "CMD", "INSTALL", "-l")
+
+ /** R source code should exist under R/pkg in a jar. */
+ private final val RJarEntries = "R/pkg"
+
+ /** Documentation on how the R source file layout should be in the jar. */
+ private[deploy] final val RJarDoc =
+ s"""In order for Spark to build R packages that are parts of Spark Packages, there are a few
+ |requirements. The R source code must be shipped in a jar, with additional Java/Scala
+ |classes. The jar must be in the following format:
+ | 1- The Manifest (META-INF/MANIFEST.mf) must contain the key-value: $hasRPackage: true
+ | 2- The standard R package layout must be preserved under R/pkg/ inside the jar. More
+ | information on the standard R package layout can be found in:
+ | http://cran.r-project.org/doc/contrib/Leisch-CreatingPackages.pdf
+ | An example layout is given below. After running `jar tf $$JAR_FILE | sort`:
+ |
+ |META-INF/MANIFEST.MF
+ |R/
+ |R/pkg/
+ |R/pkg/DESCRIPTION
+ |R/pkg/NAMESPACE
+ |R/pkg/R/
+ |R/pkg/R/myRcode.R
+ |org/
+ |org/apache/
+ |...
+ """.stripMargin.trim
+
+ /** Internal method for logging. We log to a printStream in tests, for debugging purposes. */
+ private def print(
+ msg: String,
+ printStream: PrintStream,
+ level: Level = Level.FINE,
+ e: Throwable = null): Unit = {
+ if (printStream != null) {
+ // scalastyle:off println
+ printStream.println(msg)
+ // scalastyle:on println
+ if (e != null) {
+ e.printStackTrace(printStream)
+ }
+ } else {
+ level match {
+ case Level.INFO => logInfo(msg)
+ case Level.WARNING => logWarning(msg)
+ case Level.SEVERE => logError(msg, e)
+ case _ => logDebug(msg)
+ }
+ }
+ }
+
+ /**
+ * Checks the manifest of the Jar whether there is any R source code bundled with it.
+ * Exposed for testing.
+ */
+ private[deploy] def checkManifestForR(jar: JarFile): Boolean = {
+ val manifest = jar.getManifest.getMainAttributes
+ manifest.getValue(hasRPackage) != null && manifest.getValue(hasRPackage).trim == "true"
+ }
+
+ /**
+ * Runs the standard R package installation code to build the R package from source.
+ * Multiple runs don't cause problems.
+ */
+ private def rPackageBuilder(dir: File, printStream: PrintStream, verbose: Boolean): Boolean = {
+ // this code should be always running on the driver.
+ val pathToSparkR = RUtils.localSparkRPackagePath.getOrElse(
+ throw new SparkException("SPARK_HOME not set. Can't locate SparkR package."))
+ val pathToPkg = Seq(dir, "R", "pkg").mkString(File.separator)
+ val installCmd = baseInstallCmd ++ Seq(pathToSparkR, pathToPkg)
+ if (verbose) {
+ print(s"Building R package with the command: $installCmd", printStream)
+ }
+ try {
+ val builder = new ProcessBuilder(installCmd)
+ builder.redirectErrorStream(true)
+ val env = builder.environment()
+ env.clear()
+ val process = builder.start()
+ new RedirectThread(process.getInputStream, printStream, "redirect R packaging").start()
+ process.waitFor() == 0
+ } catch {
+ case e: Throwable =>
+ print("Failed to build R package.", printStream, Level.SEVERE, e)
+ false
+ }
+ }
+
+ /**
+ * Extracts the files under /R in the jar to a temporary directory for building.
+ */
+ private def extractRFolder(jar: JarFile, printStream: PrintStream, verbose: Boolean): File = {
+ val tempDir = Utils.createTempDir(null)
+ val jarEntries = jar.entries()
+ while (jarEntries.hasMoreElements) {
+ val entry = jarEntries.nextElement()
+ val entryRIndex = entry.getName.indexOf(RJarEntries)
+ if (entryRIndex > -1) {
+ val entryPath = entry.getName.substring(entryRIndex)
+ if (entry.isDirectory) {
+ val dir = new File(tempDir, entryPath)
+ if (verbose) {
+ print(s"Creating directory: $dir", printStream)
+ }
+ dir.mkdirs
+ } else {
+ val inStream = jar.getInputStream(entry)
+ val outPath = new File(tempDir, entryPath)
+ Files.createParentDirs(outPath)
+ val outStream = new FileOutputStream(outPath)
+ if (verbose) {
+ print(s"Extracting $entry to $outPath", printStream)
+ }
+ Utils.copyStream(inStream, outStream, closeStreams = true)
+ }
+ }
+ }
+ tempDir
+ }
+
+ /**
+   * Checks each jar in the comma-separated list for bundled R source code and builds the R package.
+ */
+ private[deploy] def checkAndBuildRPackage(
+ jars: String,
+ printStream: PrintStream = null,
+ verbose: Boolean = false): Unit = {
+ jars.split(",").foreach { jarPath =>
+ val file = new File(Utils.resolveURI(jarPath))
+ if (file.exists()) {
+ val jar = new JarFile(file)
+ if (checkManifestForR(jar)) {
+ print(s"$file contains R source code. Now installing package.", printStream, Level.INFO)
+ val rSource = extractRFolder(jar, printStream, verbose)
+ try {
+ if (!rPackageBuilder(rSource, printStream, verbose)) {
+ print(s"ERROR: Failed to build R package in $file.", printStream)
+ print(RJarDoc, printStream)
+ }
+ } finally {
+ rSource.delete() // clean up
+ }
+ } else {
+ if (verbose) {
+ print(s"$file doesn't contain R source code, skipping...", printStream)
+ }
+ }
+ } else {
+ print(s"WARN: $file resolved as dependency, but not found.", printStream, Level.WARNING)
+ }
+ }
+ }
+
+ private def listFilesRecursively(dir: File, excludePatterns: Seq[String]): Set[File] = {
+ if (!dir.exists()) {
+ Set.empty[File]
+ } else {
+ if (dir.isDirectory) {
+ val subDir = dir.listFiles(new FilenameFilter {
+ override def accept(dir: File, name: String): Boolean = {
+ !excludePatterns.map(name.contains).reduce(_ || _) // exclude files with given pattern
+ }
+ })
+ subDir.flatMap(listFilesRecursively(_, excludePatterns)).toSet
+ } else {
+ Set(dir)
+ }
+ }
+ }
+
+ /** Zips all the libraries found with SparkR in the R/lib directory for distribution with Yarn. */
+ private[deploy] def zipRLibraries(dir: File, name: String): File = {
+ val filesToBundle = listFilesRecursively(dir, Seq(".zip"))
+ // create a zip file from scratch, do not append to existing file.
+ val zipFile = new File(dir, name)
+ zipFile.delete()
+ val zipOutputStream = new ZipOutputStream(new FileOutputStream(zipFile, false))
+ try {
+ filesToBundle.foreach { file =>
+ // get the relative paths for proper naming in the zip file
+ val relPath = file.getAbsolutePath.replaceFirst(dir.getAbsolutePath, "")
+ val fis = new FileInputStream(file)
+ val zipEntry = new ZipEntry(relPath)
+ zipOutputStream.putNextEntry(zipEntry)
+ ByteStreams.copy(fis, zipOutputStream)
+ zipOutputStream.closeEntry()
+ fis.close()
+ }
+ } finally {
+ zipOutputStream.close()
+ }
+ zipFile
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index e06b06e06fb4a..7e9dba42bebd8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -34,6 +34,8 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.JobContext
+import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext}
+import org.apache.hadoop.mapreduce.{TaskAttemptID => MapReduceTaskAttemptID}
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.spark.annotation.DeveloperApi
@@ -194,6 +196,18 @@ class SparkHadoopUtil extends Logging {
method.invoke(context).asInstanceOf[Configuration]
}
+ /**
+ * Using reflection to call `getTaskAttemptID` from TaskAttemptContext. If we directly
+ * call `TaskAttemptContext.getTaskAttemptID`, it will generate different byte codes
+   * for Hadoop 1.+ and Hadoop 2.+ because TaskAttemptContext is a class in Hadoop 1.+
+   * while it is an interface in Hadoop 2.+.
+ */
+ def getTaskAttemptIDFromTaskAttemptContext(
+ context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = {
+ val method = context.getClass.getMethod("getTaskAttemptID")
+ method.invoke(context).asInstanceOf[MapReduceTaskAttemptID]
+ }
+
/**
* Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the
* given path points to a file, return a single-element collection containing [[FileStatus]] of
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 0b39ee8fe3ba0..02fa3088eded0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -24,6 +24,7 @@ import java.security.PrivilegedExceptionAction
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
+import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.fs.Path
import org.apache.hadoop.security.UserGroupInformation
import org.apache.ivy.Ivy
@@ -37,6 +38,7 @@ import org.apache.ivy.core.settings.IvySettings
import org.apache.ivy.plugins.matcher.GlobPatternMatcher
import org.apache.ivy.plugins.repository.file.FileRepository
import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver}
+
import org.apache.spark.api.r.RUtils
import org.apache.spark.SPARK_VERSION
import org.apache.spark.deploy.rest._
@@ -275,24 +277,27 @@ object SparkSubmit {
// Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files
// too for packages that include Python code
- val resolvedMavenCoordinates =
- SparkSubmitUtils.resolveMavenCoordinates(
- args.packages, Option(args.repositories), Option(args.ivyRepoPath))
- if (!resolvedMavenCoordinates.trim.isEmpty) {
- if (args.jars == null || args.jars.trim.isEmpty) {
- args.jars = resolvedMavenCoordinates
+ val exclusions: Seq[String] =
+ if (!StringUtils.isBlank(args.packagesExclusions)) {
+ args.packagesExclusions.split(",")
} else {
- args.jars += s",$resolvedMavenCoordinates"
+ Nil
}
+ val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages,
+ Option(args.repositories), Option(args.ivyRepoPath), exclusions = exclusions)
+ if (!StringUtils.isBlank(resolvedMavenCoordinates)) {
+ args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates)
if (args.isPython) {
- if (args.pyFiles == null || args.pyFiles.trim.isEmpty) {
- args.pyFiles = resolvedMavenCoordinates
- } else {
- args.pyFiles += s",$resolvedMavenCoordinates"
- }
+ args.pyFiles = mergeFileLists(args.pyFiles, resolvedMavenCoordinates)
}
}
+ // install any R packages that may have been passed through --jars or --packages.
+ // Spark Packages may contain R source code inside the jar.
+ if (args.isR && !StringUtils.isBlank(args.jars)) {
+ RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose)
+ }
+
// Require all python files to be local, so we can add them to the PYTHONPATH
// In YARN cluster mode, python files are distributed as regular files, which can be non-local
if (args.isPython && !isYarnCluster) {
@@ -362,7 +367,8 @@ object SparkSubmit {
if (rPackagePath.isEmpty) {
printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.")
}
- val rPackageFile = new File(rPackagePath.get, SPARKR_PACKAGE_ARCHIVE)
+ val rPackageFile =
+ RPackageUtils.zipRLibraries(new File(rPackagePath.get), SPARKR_PACKAGE_ARCHIVE)
if (!rPackageFile.exists()) {
printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.")
}
@@ -416,7 +422,8 @@ object SparkSubmit {
// Yarn client only
OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"),
- OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"),
+ OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES,
+ sysProp = "spark.executor.instances"),
OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"),
OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"),
OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"),
@@ -427,7 +434,6 @@ object SparkSubmit {
OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"),
OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"),
OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"),
- OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"),
OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"),
OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"),
OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"),
@@ -736,7 +742,7 @@ object SparkSubmit {
* no files, into a single comma-separated string.
*/
private def mergeFileLists(lists: String*): String = {
- val merged = lists.filter(_ != null)
+ val merged = lists.filterNot(StringUtils.isBlank)
.flatMap(_.split(","))
.mkString(",")
if (merged == "") null else merged
@@ -938,7 +944,7 @@ private[spark] object SparkSubmitUtils {
// are supplied to spark-submit
val alternateIvyCache = ivyPath.getOrElse("")
val packagesDirectory: File =
- if (alternateIvyCache.trim.isEmpty) {
+ if (alternateIvyCache == null || alternateIvyCache.trim.isEmpty) {
new File(ivySettings.getDefaultIvyUserDir, "jars")
} else {
ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache))
@@ -988,11 +994,9 @@ private[spark] object SparkSubmitUtils {
addExclusionRules(ivySettings, ivyConfName, md)
// add all supplied maven artifacts as dependencies
addDependenciesToIvy(md, artifacts, ivyConfName)
-
exclusions.foreach { e =>
md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName))
}
-
// resolve dependencies
val rr: ResolveReport = ivy.resolve(md, resolveOptions)
if (rr.hasError) {
@@ -1010,7 +1014,7 @@ private[spark] object SparkSubmitUtils {
}
}
- private def createExclusion(
+ private[deploy] def createExclusion(
coords: String,
ivySettings: IvySettings,
ivyConfName: String): ExcludeRule = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index b3710073e330c..3f3c6627c21fb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
var packages: String = null
var repositories: String = null
var ivyRepoPath: String = null
+ var packagesExclusions: String = null
var verbose: Boolean = false
var isPython: Boolean = false
var pyFiles: String = null
@@ -172,6 +173,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull
jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull
ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull
+ packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull
+ packagesExclusions = Option(packagesExclusions)
+ .orElse(sparkProperties.get("spark.jars.excludes")).orNull
deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
numExecutors = Option(numExecutors)
.getOrElse(sparkProperties.get("spark.executor.instances").orNull)
@@ -299,6 +303,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
| childArgs [${childArgs.mkString(" ")}]
| jars $jars
| packages $packages
+ | packagesExclusions $packagesExclusions
| repositories $repositories
| verbose $verbose
|
@@ -391,6 +396,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
case PACKAGES =>
packages = value
+ case PACKAGES_EXCLUDE =>
+ packagesExclusions = value
+
case REPOSITORIES =>
repositories = value
@@ -482,6 +490,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
| maven repo, then maven central and any additional remote
| repositories given by --repositories. The format for the
| coordinates should be groupId:artifactId:version.
+ | --exclude-packages Comma-separated list of groupId:artifactId, to exclude while
+ | resolving the dependencies provided in --packages to avoid
+ | dependency conflicts.
| --repositories Comma-separated list of additional remote repositories to
| search for the maven coordinates given with --packages.
| --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place
@@ -600,5 +611,4 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
System.setErr(currentErr)
}
}
-
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 79b251e7e62fe..25ea6925434ab 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -27,7 +27,7 @@ import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
import org.apache.spark.rpc._
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
/**
* Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL,
@@ -197,6 +197,22 @@ private[spark] class AppClient(
sendToMaster(UnregisterApplication(appId))
context.reply(true)
stop()
+
+ case r: RequestExecutors =>
+ master match {
+ case Some(m) => context.reply(m.askWithRetry[Boolean](r))
+ case None =>
+ logWarning("Attempted to request executors before registering with Master.")
+ context.reply(false)
+ }
+
+ case k: KillExecutors =>
+ master match {
+ case Some(m) => context.reply(m.askWithRetry[Boolean](k))
+ case None =>
+ logWarning("Attempted to kill executors before registering with Master.")
+ context.reply(false)
+ }
}
override def onDisconnected(address: RpcAddress): Unit = {
@@ -241,14 +257,15 @@ private[spark] class AppClient(
}
def start() {
- // Just launch an actor; it will call back into the listener.
+ // Just launch an rpcEndpoint; it will call back into the listener.
endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv))
}
def stop() {
if (endpoint != null) {
try {
- endpoint.askWithRetry[Boolean](StopAppClient)
+ val timeout = RpcUtils.askRpcTimeout(conf)
+ timeout.awaitResult(endpoint.ask[Boolean](StopAppClient))
} catch {
case e: TimeoutException =>
logInfo("Stop request to Master timed out; it may already be shut down.")
@@ -256,4 +273,33 @@ private[spark] class AppClient(
endpoint = null
}
}
+
+ /**
+ * Request executors from the Master by specifying the total number desired,
+ * including existing pending and running executors.
+ *
+ * @return whether the request is acknowledged.
+ */
+ def requestTotalExecutors(requestedTotal: Int): Boolean = {
+ if (endpoint != null && appId != null) {
+ endpoint.askWithRetry[Boolean](RequestExecutors(appId, requestedTotal))
+ } else {
+ logWarning("Attempted to request executors before driver fully initialized.")
+ false
+ }
+ }
+
+ /**
+ * Kill the given list of executors through the Master.
+ * @return whether the kill request is acknowledged.
+ */
+ def killExecutors(executorIds: Seq[String]): Boolean = {
+ if (endpoint != null && appId != null) {
+ endpoint.askWithRetry[Boolean](KillExecutors(appId, executorIds))
+ } else {
+ logWarning("Attempted to kill executors before driver fully initialized.")
+ false
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index e3060ac3fa1a9..53c18ca3ff50c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -272,9 +272,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
* Replay the log files in the list and merge the list of old applications with new ones
*/
private def mergeApplicationListing(logs: Seq[FileStatus]): Unit = {
- val bus = new ReplayListenerBus()
val newAttempts = logs.flatMap { fileStatus =>
try {
+ val bus = new ReplayListenerBus()
val res = replay(fileStatus, bus)
res match {
case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.")
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index a076a9c3f984d..d4f327cc588fe 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -30,7 +30,7 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, Applica
UIRoot}
import org.apache.spark.ui.{SparkUI, UIUtils, WebUI}
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.util.{SignalLogger, Utils}
+import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils}
/**
* A web server that renders SparkUIs of completed applications.
@@ -238,7 +238,7 @@ object HistoryServer extends Logging {
val server = new HistoryServer(conf, provider, securityManager, port)
server.bind()
- Utils.addShutdownHook { () => server.stop() }
+ ShutdownHookManager.addShutdownHook { () => server.stop() }
// Wait until the end of the world... or if the HistoryServer process is manually stopped
while(true) { Thread.sleep(Int.MaxValue) }
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index aa54ed9360f36..b40d20f9f7868 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -22,7 +22,6 @@ import java.util.Date
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.ApplicationDescription
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
@@ -43,6 +42,11 @@ private[spark] class ApplicationInfo(
@transient var endTime: Long = _
@transient var appSource: ApplicationSource = _
+ // A cap on the number of executors this application can have at any given time.
+ // By default, this is infinite. Only after the first allocation request is issued by the
+ // application will this be set to a finite value. This is used for dynamic allocation.
+ @transient private[master] var executorLimit: Int = _
+
@transient private var nextExecutorId: Int = _
init()
@@ -60,6 +64,7 @@ private[spark] class ApplicationInfo(
appSource = new ApplicationSource(this)
nextExecutorId = 0
removedExecutors = new ArrayBuffer[ExecutorDesc]
+ executorLimit = Integer.MAX_VALUE
}
private def newExecutorId(useID: Option[Int] = None): Int = {
@@ -116,6 +121,12 @@ private[spark] class ApplicationInfo(
state != ApplicationState.WAITING && state != ApplicationState.RUNNING
}
+ /**
+ * Return the limit on the number of executors this application can have.
+ * For testing only.
+ */
+ private[deploy] def getExecutorLimit: Int = executorLimit
+
def duration: Long = {
if (endTime != -1) {
endTime - startTime
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index cf77c86d760cf..70f21fbe0de85 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.DeveloperApi
*/
@DeveloperApi
trait LeaderElectionAgent {
- val masterActor: LeaderElectable
+ val masterInstance: LeaderElectable
def stop() {} // to avoid noops in implementations.
}
@@ -37,7 +37,7 @@ trait LeaderElectable {
}
/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
-private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable)
+private[spark] class MonarchyLeaderAgent(val masterInstance: LeaderElectable)
extends LeaderElectionAgent {
- masterActor.electedLeader()
+ masterInstance.electedLeader()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 51b3f0dead73e..9217202b69a66 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -45,7 +45,7 @@ import org.apache.spark.serializer.{JavaSerializer, Serializer}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
-private[master] class Master(
+private[deploy] class Master(
override val rpcEnv: RpcEnv,
address: RpcAddress,
webUiPort: Int,
@@ -468,6 +468,13 @@ private[master] class Master(
case BoundPortsRequest => {
context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort))
}
+
+ case RequestExecutors(appId, requestedTotal) =>
+ context.reply(handleRequestExecutors(appId, requestedTotal))
+
+ case KillExecutors(appId, executorIds) =>
+ val formattedExecutorIds = formatExecutorIds(executorIds)
+ context.reply(handleKillExecutors(appId, formattedExecutorIds))
}
override def onDisconnected(address: RpcAddress): Unit = {
@@ -563,32 +570,51 @@ private[master] class Master(
app: ApplicationInfo,
usableWorkers: Array[WorkerInfo],
spreadOutApps: Boolean): Array[Int] = {
- // If the number of cores per executor is not specified, then we can just schedule
- // 1 core at a time since we expect a single executor to be launched on each worker
- val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(1)
+ val coresPerExecutor = app.desc.coresPerExecutor
+ val minCoresPerExecutor = coresPerExecutor.getOrElse(1)
+ val oneExecutorPerWorker = coresPerExecutor.isEmpty
val memoryPerExecutor = app.desc.memoryPerExecutorMB
val numUsable = usableWorkers.length
val assignedCores = new Array[Int](numUsable) // Number of cores to give to each worker
- val assignedMemory = new Array[Int](numUsable) // Amount of memory to give to each worker
+ val assignedExecutors = new Array[Int](numUsable) // Number of new executors on each worker
var coresToAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum)
- var freeWorkers = (0 until numUsable).toIndexedSeq
+ /** Return whether the specified worker can launch an executor for this app. */
def canLaunchExecutor(pos: Int): Boolean = {
- usableWorkers(pos).coresFree - assignedCores(pos) >= coresPerExecutor &&
- usableWorkers(pos).memoryFree - assignedMemory(pos) >= memoryPerExecutor
+ val keepScheduling = coresToAssign >= minCoresPerExecutor
+ val enoughCores = usableWorkers(pos).coresFree - assignedCores(pos) >= minCoresPerExecutor
+
+ // If we allow multiple executors per worker, then we can always launch new executors.
+ // Otherwise, if there is already an executor on this worker, just give it more cores.
+ val launchingNewExecutor = !oneExecutorPerWorker || assignedExecutors(pos) == 0
+ if (launchingNewExecutor) {
+ val assignedMemory = assignedExecutors(pos) * memoryPerExecutor
+ val enoughMemory = usableWorkers(pos).memoryFree - assignedMemory >= memoryPerExecutor
+ val underLimit = assignedExecutors.sum + app.executors.size < app.executorLimit
+ keepScheduling && enoughCores && enoughMemory && underLimit
+ } else {
+ // We're adding cores to an existing executor, so no need
+ // to check memory and executor limits
+ keepScheduling && enoughCores
+ }
}
- while (coresToAssign >= coresPerExecutor && freeWorkers.nonEmpty) {
- freeWorkers = freeWorkers.filter(canLaunchExecutor)
+ // Keep launching executors until no more workers can accommodate any
+ // more executors, or if we have reached this application's limits
+ var freeWorkers = (0 until numUsable).filter(canLaunchExecutor)
+ while (freeWorkers.nonEmpty) {
freeWorkers.foreach { pos =>
var keepScheduling = true
- while (keepScheduling && canLaunchExecutor(pos) && coresToAssign >= coresPerExecutor) {
- coresToAssign -= coresPerExecutor
- assignedCores(pos) += coresPerExecutor
- // If cores per executor is not set, we are assigning 1 core at a time
- // without actually meaning to launch 1 executor for each core assigned
- if (app.desc.coresPerExecutor.isDefined) {
- assignedMemory(pos) += memoryPerExecutor
+ while (keepScheduling && canLaunchExecutor(pos)) {
+ coresToAssign -= minCoresPerExecutor
+ assignedCores(pos) += minCoresPerExecutor
+
+ // If we are launching one executor per worker, then every iteration assigns 1 core
+ // to the executor. Otherwise, every iteration assigns cores to a new executor.
+ if (oneExecutorPerWorker) {
+ assignedExecutors(pos) = 1
+ } else {
+ assignedExecutors(pos) += 1
}
// Spreading out an application means spreading out its executors across as
@@ -600,6 +626,7 @@ private[master] class Master(
}
}
}
+ freeWorkers = freeWorkers.filter(canLaunchExecutor)
}
assignedCores
}
@@ -785,9 +812,7 @@ private[master] class Master(
rebuildSparkUI(app)
for (exec <- app.executors.values) {
- exec.worker.removeExecutor(exec)
- exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id))
- exec.state = ExecutorState.KILLED
+ killExecutor(exec)
}
app.markFinished(state)
if (state != ApplicationState.FINISHED) {
@@ -803,6 +828,87 @@ private[master] class Master(
}
}
+ /**
+ * Handle a request to set the target number of executors for this application.
+ *
+ * If the executor limit is adjusted upwards, new executors will be launched provided
+ * that there are workers with sufficient resources. If it is adjusted downwards, however,
+ * we do not kill existing executors until we explicitly receive a kill request.
+ *
+ * @return whether the application has previously registered with this Master.
+ */
+ private def handleRequestExecutors(appId: String, requestedTotal: Int): Boolean = {
+ idToApp.get(appId) match {
+ case Some(appInfo) =>
+ logInfo(s"Application $appId requested to set total executors to $requestedTotal.")
+ appInfo.executorLimit = requestedTotal
+ schedule()
+ true
+ case None =>
+ logWarning(s"Unknown application $appId requested $requestedTotal total executors.")
+ false
+ }
+ }
+
+ /**
+ * Handle a kill request from the given application.
+ *
+ * This method assumes the executor limit has already been adjusted downwards through
+ * a separate [[RequestExecutors]] message, such that we do not launch new executors
+ * immediately after the old ones are removed.
+ *
+ * @return whether the application has previously registered with this Master.
+ */
+ private def handleKillExecutors(appId: String, executorIds: Seq[Int]): Boolean = {
+ idToApp.get(appId) match {
+ case Some(appInfo) =>
+ logInfo(s"Application $appId requests to kill executors: " + executorIds.mkString(", "))
+ val (known, unknown) = executorIds.partition(appInfo.executors.contains)
+ known.foreach { executorId =>
+ val desc = appInfo.executors(executorId)
+ appInfo.removeExecutor(desc)
+ killExecutor(desc)
+ }
+ if (unknown.nonEmpty) {
+ logWarning(s"Application $appId attempted to kill non-existent executors: "
+ + unknown.mkString(", "))
+ }
+ schedule()
+ true
+ case None =>
+ logWarning(s"Unregistered application $appId requested us to kill executors!")
+ false
+ }
+ }
+
+ /**
+ * Cast the given executor IDs to integers and filter out the ones that fail.
+ *
+ * All executor IDs should be integers since we launched these executors. However,
+ * the kill interface on the driver side accepts arbitrary strings, so we need to
+ * handle non-integer executor IDs just to be safe.
+ */
+ private def formatExecutorIds(executorIds: Seq[String]): Seq[Int] = {
+ executorIds.flatMap { executorId =>
+ try {
+ Some(executorId.toInt)
+ } catch {
+ case e: NumberFormatException =>
+ logError(s"Encountered executor with a non-integer ID: $executorId. Ignoring")
+ None
+ }
+ }
+ }
+
+ /**
+ * Ask the worker on which the specified executor is launched to kill the executor.
+ */
+ private def killExecutor(exec: ExecutorDesc): Unit = {
+ exec.worker.removeExecutor(exec)
+ exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id))
+ exec.state = ExecutorState.KILLED
+ }
+
/**
* Rebuild a new SparkUI from the given application's event logs.
* Return the UI if successful, else None
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
index 68c937188b333..a952cee36eb44 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
@@ -38,5 +38,5 @@ private[master] object MasterMessages {
case object BoundPortsRequest
- case class BoundPortsResponse(actorPort: Int, webUIPort: Int, restPort: Option[Int])
+ case class BoundPortsResponse(rpcEndpointPort: Int, webUIPort: Int, restPort: Option[Int])
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 6fdff86f66e01..d317206a614fb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -22,7 +22,7 @@ import org.apache.curator.framework.CuratorFramework
import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch}
import org.apache.spark.deploy.SparkCuratorUtil
-private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable,
+private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderElectable,
conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging {
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
@@ -73,10 +73,10 @@ private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElecta
private def updateLeadershipStatus(isLeader: Boolean) {
if (isLeader && status == LeadershipStatus.NOT_LEADER) {
status = LeadershipStatus.LEADER
- masterActor.electedLeader()
+ masterInstance.electedLeader()
} else if (!isLeader && status == LeadershipStatus.LEADER) {
status = LeadershipStatus.NOT_LEADER
- masterActor.revokedLeadership()
+ masterInstance.revokedLeadership()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
new file mode 100644
index 0000000000000..061857476a8a0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.mesos
+
+import java.net.SocketAddress
+
+import scala.collection.mutable
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.deploy.ExternalShuffleService
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage
+import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver
+import org.apache.spark.network.util.TransportConf
+
+/**
+ * An RPC endpoint that receives registration requests from Spark drivers running on Mesos.
+ * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]].
+ */
+private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf)
+ extends ExternalShuffleBlockHandler(transportConf) with Logging {
+
+ // Stores a map of driver socket addresses to app ids
+ private val connectedApps = new mutable.HashMap[SocketAddress, String]
+
+ protected override def handleMessage(
+ message: BlockTransferMessage,
+ client: TransportClient,
+ callback: RpcResponseCallback): Unit = {
+ message match {
+ case RegisterDriverParam(appId) =>
+ val address = client.getSocketAddress
+ logDebug(s"Received registration request from app $appId (remote address $address).")
+ if (connectedApps.contains(address)) {
+ val existingAppId = connectedApps(address)
+ if (!existingAppId.equals(appId)) {
+ logError(s"A new app '$appId' has connected to existing address $address, " +
+ s"removing previously registered app '$existingAppId'.")
+ applicationRemoved(existingAppId, true)
+ }
+ }
+ connectedApps(address) = appId
+ callback.onSuccess(new Array[Byte](0))
+ case _ => super.handleMessage(message, client, callback)
+ }
+ }
+
+ /**
+ * On connection termination, clean up shuffle files written by the associated application.
+ */
+ override def connectionTerminated(client: TransportClient): Unit = {
+ val address = client.getSocketAddress
+ if (connectedApps.contains(address)) {
+ val appId = connectedApps(address)
+ logInfo(s"Application $appId disconnected (address was $address).")
+ applicationRemoved(appId, true /* cleanupLocalDirs */)
+ connectedApps.remove(address)
+ } else {
+ logWarning(s"Unknown $address disconnected.")
+ }
+ }
+
+ /** An extractor object for matching [[RegisterDriver]] message. */
+ private object RegisterDriverParam {
+ def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId)
+ }
+}
+
+/**
+ * A wrapper of [[ExternalShuffleService]] that provides an additional endpoint for drivers
+ * to associate with. This allows the shuffle service to detect when a driver is terminated
+ * and can clean up the associated shuffle files.
+ */
+private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManager: SecurityManager)
+ extends ExternalShuffleService(conf, securityManager) {
+
+ protected override def newShuffleBlockHandler(
+ conf: TransportConf): ExternalShuffleBlockHandler = {
+ new MesosExternalShuffleBlockHandler(conf)
+ }
+}
+
+private[spark] object MesosExternalShuffleService extends Logging {
+
+ def main(args: Array[String]): Unit = {
+ ExternalShuffleService.main(args,
+ (conf: SparkConf, sm: SecurityManager) => new MesosExternalShuffleService(conf, sm))
+ }
+}
+
+
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 29a5042285578..ab3fea475c2a5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -28,7 +28,7 @@ import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ShutdownHookManager, Utils}
import org.apache.spark.util.logging.FileAppender
/**
@@ -70,7 +70,8 @@ private[deploy] class ExecutorRunner(
}
workerThread.start()
// Shutdown hook that kills actors on shutdown.
- shutdownHook = Utils.addShutdownHook { () => killProcess(Some("Worker shutting down")) }
+ shutdownHook = ShutdownHookManager.addShutdownHook { () =>
+ killProcess(Some("Worker shutting down")) }
}
/**
@@ -102,7 +103,7 @@ private[deploy] class ExecutorRunner(
workerThread = null
state = ExecutorState.KILLED
try {
- Utils.removeShutdownHook(shutdownHook)
+ ShutdownHookManager.removeShutdownHook(shutdownHook)
} catch {
case e: IllegalStateException => None
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 82e9578bbcba5..79b1536d94016 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -25,7 +25,7 @@ import java.util.concurrent._
import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture}
import scala.collection.JavaConversions._
-import scala.collection.mutable.{HashMap, HashSet}
+import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap}
import scala.concurrent.ExecutionContext
import scala.util.Random
import scala.util.control.NonFatal
@@ -40,7 +40,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rpc._
import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}
-private[worker] class Worker(
+private[deploy] class Worker(
override val rpcEnv: RpcEnv,
webUiPort: Int,
cores: Int,
@@ -115,13 +115,18 @@ private[worker] class Worker(
}
var workDir: File = null
- val finishedExecutors = new HashMap[String, ExecutorRunner]
+ val finishedExecutors = new LinkedHashMap[String, ExecutorRunner]
val drivers = new HashMap[String, DriverRunner]
val executors = new HashMap[String, ExecutorRunner]
- val finishedDrivers = new HashMap[String, DriverRunner]
+ val finishedDrivers = new LinkedHashMap[String, DriverRunner]
val appDirectories = new HashMap[String, Seq[String]]
val finishedApps = new HashSet[String]
+ val retainedExecutors = conf.getInt("spark.worker.ui.retainedExecutors",
+ WorkerWebUI.DEFAULT_RETAINED_EXECUTORS)
+ val retainedDrivers = conf.getInt("spark.worker.ui.retainedDrivers",
+ WorkerWebUI.DEFAULT_RETAINED_DRIVERS)
+
// The shuffle service is not actually started unless configured.
private val shuffleService = new ExternalShuffleService(conf, securityMgr)
@@ -223,7 +228,7 @@ private[worker] class Worker(
/**
* Re-register with the master because a network failure or a master failure has occurred.
* If the re-registration attempt threshold is exceeded, the worker exits with error.
- * Note that for thread-safety this should only be called from the actor.
+ * Note that for thread-safety this should only be called from the rpcEndpoint.
*/
private def reregisterWithMaster(): Unit = {
Utils.tryOrExit {
@@ -360,7 +365,8 @@ private[worker] class Worker(
if (connected) { sendToMaster(Heartbeat(workerId, self)) }
case WorkDirCleanup =>
- // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor
+ // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker
+ // rpcEndpoint.
// Copy ids so that it can be used in the cleanup thread.
val appIds = executors.values.map(_.appId).toSet
val cleanupFuture = concurrent.future {
@@ -422,7 +428,9 @@ private[worker] class Worker(
// application finishes.
val appLocalDirs = appDirectories.get(appId).getOrElse {
Utils.getOrCreateLocalRootDirs(conf).map { dir =>
- Utils.createDirectory(dir, namePrefix = "executor").getAbsolutePath()
+ val appDir = Utils.createDirectory(dir, namePrefix = "executor")
+ Utils.chmod700(appDir)
+ appDir.getAbsolutePath()
}.toSeq
}
appDirectories(appId) = appLocalDirs
@@ -461,25 +469,7 @@ private[worker] class Worker(
}
case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
- sendToMaster(executorStateChanged)
- val fullId = appId + "/" + execId
- if (ExecutorState.isFinished(state)) {
- executors.get(fullId) match {
- case Some(executor) =>
- logInfo("Executor " + fullId + " finished with state " + state +
- message.map(" message " + _).getOrElse("") +
- exitStatus.map(" exitStatus " + _).getOrElse(""))
- executors -= fullId
- finishedExecutors(fullId) = executor
- coresUsed -= executor.cores
- memoryUsed -= executor.memory
- case None =>
- logInfo("Unknown Executor " + fullId + " finished with state " + state +
- message.map(" message " + _).getOrElse("") +
- exitStatus.map(" exitStatus " + _).getOrElse(""))
- }
- maybeCleanupApplication(appId)
- }
+ handleExecutorStateChanged(executorStateChanged)
case KillExecutor(masterUrl, appId, execId) =>
if (masterUrl != activeMasterUrl) {
@@ -523,24 +513,8 @@ private[worker] class Worker(
}
}
- case driverStageChanged @ DriverStateChanged(driverId, state, exception) => {
- state match {
- case DriverState.ERROR =>
- logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}")
- case DriverState.FAILED =>
- logWarning(s"Driver $driverId exited with failure")
- case DriverState.FINISHED =>
- logInfo(s"Driver $driverId exited successfully")
- case DriverState.KILLED =>
- logInfo(s"Driver $driverId was killed by user")
- case _ =>
- logDebug(s"Driver $driverId changed state to $state")
- }
- sendToMaster(driverStageChanged)
- val driver = drivers.remove(driverId).get
- finishedDrivers(driverId) = driver
- memoryUsed -= driver.driverDesc.mem
- coresUsed -= driver.driverDesc.cores
+ case driverStateChanged @ DriverStateChanged(driverId, state, exception) => {
+ handleDriverStateChanged(driverStateChanged)
}
case ReregisterWithMaster =>
@@ -582,6 +556,7 @@ private[worker] class Worker(
Utils.deleteRecursively(new File(dir))
}
}
+ shuffleService.applicationRemoved(id)
}
}
@@ -614,9 +589,84 @@ private[worker] class Worker(
webUi.stop()
metricsSystem.stop()
}
+
+ private def trimFinishedExecutorsIfNecessary(): Unit = {
+ // do not need to protect with locks since both WorkerPage and Restful server get data through
+ // thread-safe RpcEndPoint
+ if (finishedExecutors.size > retainedExecutors) {
+ finishedExecutors.take(math.max(finishedExecutors.size / 10, 1)).foreach {
+ case (executorId, _) => finishedExecutors.remove(executorId)
+ }
+ }
+ }
+
+ private def trimFinishedDriversIfNecessary(): Unit = {
+ // do not need to protect with locks since both WorkerPage and Restful server get data through
+ // thread-safe RpcEndPoint
+ if (finishedDrivers.size > retainedDrivers) {
+ finishedDrivers.take(math.max(finishedDrivers.size / 10, 1)).foreach {
+ case (driverId, _) => finishedDrivers.remove(driverId)
+ }
+ }
+ }
+
+ private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = {
+ val driverId = driverStateChanged.driverId
+ val exception = driverStateChanged.exception
+ val state = driverStateChanged.state
+ state match {
+ case DriverState.ERROR =>
+ logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}")
+ case DriverState.FAILED =>
+ logWarning(s"Driver $driverId exited with failure")
+ case DriverState.FINISHED =>
+ logInfo(s"Driver $driverId exited successfully")
+ case DriverState.KILLED =>
+ logInfo(s"Driver $driverId was killed by user")
+ case _ =>
+ logDebug(s"Driver $driverId changed state to $state")
+ }
+ sendToMaster(driverStateChanged)
+ val driver = drivers.remove(driverId).get
+ finishedDrivers(driverId) = driver
+ trimFinishedDriversIfNecessary()
+ memoryUsed -= driver.driverDesc.mem
+ coresUsed -= driver.driverDesc.cores
+ }
+
+ private[worker] def handleExecutorStateChanged(executorStateChanged: ExecutorStateChanged):
+ Unit = {
+ sendToMaster(executorStateChanged)
+ val state = executorStateChanged.state
+ if (ExecutorState.isFinished(state)) {
+ val appId = executorStateChanged.appId
+ val fullId = appId + "/" + executorStateChanged.execId
+ val message = executorStateChanged.message
+ val exitStatus = executorStateChanged.exitStatus
+ executors.get(fullId) match {
+ case Some(executor) =>
+ logInfo("Executor " + fullId + " finished with state " + state +
+ message.map(" message " + _).getOrElse("") +
+ exitStatus.map(" exitStatus " + _).getOrElse(""))
+ executors -= fullId
+ finishedExecutors(fullId) = executor
+ trimFinishedExecutorsIfNecessary()
+ coresUsed -= executor.cores
+ memoryUsed -= executor.memory
+ case None =>
+ logInfo("Unknown Executor " + fullId + " finished with state " + state +
+ message.map(" message " + _).getOrElse("") +
+ exitStatus.map(" exitStatus " + _).getOrElse(""))
+ }
+ maybeCleanupApplication(appId)
+ }
+ }
}
private[deploy] object Worker extends Logging {
+ val SYSTEM_NAME = "sparkWorker"
+ val ENDPOINT_NAME = "Worker"
+
def main(argStrings: Array[String]) {
SignalLogger.register(log)
val conf = new SparkConf
@@ -637,14 +687,13 @@ private[deploy] object Worker extends Logging {
workerNumber: Option[Int] = None,
conf: SparkConf = new SparkConf): RpcEnv = {
- // The LocalSparkCluster runs multiple local sparkWorkerX actor systems
- val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
- val actorName = "Worker"
+ // The LocalSparkCluster runs multiple local sparkWorkerX RPC environments
+ val systemName = SYSTEM_NAME + workerNumber.map(_.toString).getOrElse("")
val securityMgr = new SecurityManager(conf)
val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr)
val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_))
- rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, webUiPort, cores, memory, masterAddresses,
- systemName, actorName, workDir, conf, securityMgr))
+ rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory,
+ masterAddresses, systemName, ENDPOINT_NAME, workDir, conf, securityMgr))
rpcEnv
}
@@ -669,5 +718,4 @@ private[deploy] object Worker extends Logging {
cmd
}
}
-
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index fae5640b9a213..735c4f0927150 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -43,7 +43,7 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin
private[deploy] def setTesting(testing: Boolean) = isTesting = testing
private var isTesting = false
- // Lets us filter events only from the worker's actor system
+ // Lets us filter events only from the worker's RPC system
private val expectedAddress = RpcAddress.fromURIString(workerUrl)
private def isWorker(address: RpcAddress) = expectedAddress == address
@@ -62,7 +62,7 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
if (isWorker(remoteAddress)) {
// This log message will never be seen
- logError(s"Lost connection to worker actor $workerUrl. Exiting.")
+ logError(s"Lost connection to worker rpc endpoint $workerUrl. Exiting.")
exitNonZero()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index 334a5b10142aa..709a27233598c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -53,6 +53,8 @@ class WorkerWebUI(
}
}
-private[ui] object WorkerWebUI {
+private[worker] object WorkerWebUI {
val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR
+ val DEFAULT_RETAINED_DRIVERS = 1000
+ val DEFAULT_RETAINED_EXECUTORS = 1000
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 7bc7fce7ae8dd..42a85e42ea2b6 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import java.io.File
+import java.io.{File, NotSerializableException}
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
@@ -249,6 +249,7 @@ private[spark] class Executor(
m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
m.setJvmGCTime(computeTotalGcTime() - startGCTime)
m.setResultSerializationTime(afterSerialization - beforeSerialization)
+ m.updateAccumulators()
}
val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
@@ -300,11 +301,20 @@ private[spark] class Executor(
task.metrics.map { m =>
m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
m.setJvmGCTime(computeTotalGcTime() - startGCTime)
+ m.updateAccumulators()
m
}
}
- val taskEndReason = new ExceptionFailure(t, metrics)
- execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason))
+ val serializedTaskEndReason = {
+ try {
+ ser.serialize(new ExceptionFailure(t, metrics))
+ } catch {
+ case _: NotSerializableException =>
+ // t is not serializable so just send the stacktrace
+ ser.serialize(new ExceptionFailure(t, metrics, false))
+ }
+ }
+ execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index 87df42748be44..f405b732e4725 100644
--- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.mapred._
import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext}
import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.{Logging, SparkEnv, TaskContext}
import org.apache.spark.util.{Utils => SparkUtils}
@@ -93,7 +94,7 @@ object SparkHadoopMapRedUtil extends Logging {
splitId: Int,
attemptId: Int): Unit = {
- val mrTaskAttemptID = mrTaskContext.getTaskAttemptID
+ val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext)
// Called after we have decided to commit
def performCommit(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index e17bd47905d7a..72fe215dae73e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -17,156 +17,31 @@
package org.apache.spark.rdd
-import java.io.IOException
-
import scala.reflect.ClassTag
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark._
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.{Partition, SparkContext, TaskContext}
+/**
+ * An RDD partition used to recover checkpointed data.
+ */
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition
/**
- * This RDD represents a RDD checkpoint file (similar to HadoopRDD).
+ * An RDD that recovers checkpointed data from storage.
*/
-private[spark]
-class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
+private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkContext)
extends RDD[T](sc, Nil) {
- private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
-
- @transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
-
- override def getCheckpointFile: Option[String] = Some(checkpointPath)
-
- override def getPartitions: Array[Partition] = {
- val cpath = new Path(checkpointPath)
- val numPartitions =
- // listStatus can throw exception if path does not exist.
- if (fs.exists(cpath)) {
- val dirContents = fs.listStatus(cpath).map(_.getPath)
- val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted
- val numPart = partitionFiles.length
- if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
- ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) {
- throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
- }
- numPart
- } else 0
-
- Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
- }
-
- override def getPreferredLocations(split: Partition): Seq[String] = {
- val status = fs.getFileStatus(new Path(checkpointPath,
- CheckpointRDD.splitIdToFile(split.index)))
- val locations = fs.getFileBlockLocations(status, 0, status.getLen)
- locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
- }
-
- override def compute(split: Partition, context: TaskContext): Iterator[T] = {
- val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
- CheckpointRDD.readFromFile(file, broadcastedConf, context)
- }
-
// CheckpointRDD should not be checkpointed again
- override def checkpoint(): Unit = { }
override def doCheckpoint(): Unit = { }
-}
-
-private[spark] object CheckpointRDD extends Logging {
- def splitIdToFile(splitId: Int): String = {
- "part-%05d".format(splitId)
- }
-
- def writeToFile[T: ClassTag](
- path: String,
- broadcastedConf: Broadcast[SerializableConfiguration],
- blockSize: Int = -1
- )(ctx: TaskContext, iterator: Iterator[T]) {
- val env = SparkEnv.get
- val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(broadcastedConf.value.value)
-
- val finalOutputName = splitIdToFile(ctx.partitionId)
- val finalOutputPath = new Path(outputDir, finalOutputName)
- val tempOutputPath =
- new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber)
-
- if (fs.exists(tempOutputPath)) {
- throw new IOException("Checkpoint failed: temporary path " +
- tempOutputPath + " already exists")
- }
- val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
-
- val fileOutputStream = if (blockSize < 0) {
- fs.create(tempOutputPath, false, bufferSize)
- } else {
- // This is mainly for testing purpose
- fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
- }
- val serializer = env.serializer.newInstance()
- val serializeStream = serializer.serializeStream(fileOutputStream)
- Utils.tryWithSafeFinally {
- serializeStream.writeAll(iterator)
- } {
- serializeStream.close()
- }
-
- if (!fs.rename(tempOutputPath, finalOutputPath)) {
- if (!fs.exists(finalOutputPath)) {
- logInfo("Deleting tempOutputPath " + tempOutputPath)
- fs.delete(tempOutputPath, false)
- throw new IOException("Checkpoint failed: failed to save output of task: "
- + ctx.attemptNumber + " and final output path does not exist")
- } else {
- // Some other copy of this task must've finished before us and renamed it
- logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
- fs.delete(tempOutputPath, false)
- }
- }
- }
-
- def readFromFile[T](
- path: Path,
- broadcastedConf: Broadcast[SerializableConfiguration],
- context: TaskContext
- ): Iterator[T] = {
- val env = SparkEnv.get
- val fs = path.getFileSystem(broadcastedConf.value.value)
- val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
- val fileInputStream = fs.open(path, bufferSize)
- val serializer = env.serializer.newInstance()
- val deserializeStream = serializer.deserializeStream(fileInputStream)
-
- // Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener(context => deserializeStream.close())
-
- deserializeStream.asIterator.asInstanceOf[Iterator[T]]
- }
+ override def checkpoint(): Unit = { }
+ override def localCheckpoint(): this.type = this
- // Test whether CheckpointRDD generate expected number of partitions despite
- // each split file having multiple blocks. This needs to be run on a
- // cluster (mesos or standalone) using HDFS.
- def main(args: Array[String]) {
- import org.apache.spark._
+ // Note: There is a bug in MiMa that complains about `AbstractMethodProblem`s in the
+ // base [[org.apache.spark.rdd.RDD]] class if we do not override the following methods.
+ // scalastyle:off
+ protected override def getPartitions: Array[Partition] = ???
+ override def compute(p: Partition, tc: TaskContext): Iterator[T] = ???
+ // scalastyle:on
- val Array(cluster, hdfsPath) = args
- val env = SparkEnv.get
- val sc = new SparkContext(cluster, "CheckpointRDD Test")
- val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
- val path = new Path(hdfsPath, "temp")
- val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf())
- val fs = path.getFileSystem(conf)
- val broadcastedConf = sc.broadcast(new SerializableConfiguration(conf))
- sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _)
- val cpRDD = new CheckpointRDD[Int](sc, path.toString)
- assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
- assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
- fs.delete(path, true)
- }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 130b58882d8ee..9c617fc719cb5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,8 +23,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
-import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
import org.apache.spark.util.Utils
@@ -169,8 +168,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
for ((it, depNum) <- rddIterators) {
map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
}
- context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled)
- context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index f1c17369cb48c..e1f8719eead02 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
-import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils}
+import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils}
import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
import org.apache.spark.storage.StorageLevel
@@ -274,7 +274,7 @@ class HadoopRDD[K, V](
}
} catch {
case e: Exception => {
- if (!Utils.inShutdown()) {
+ if (!ShutdownHookManager.inShutdown()) {
logWarning("Exception in RecordReader.close()", e)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala
new file mode 100644
index 0000000000000..daa5779d688cc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, SparkContext, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.storage.RDDBlockId
+
+/**
+ * A dummy CheckpointRDD that exists to provide informative error messages during failures.
+ *
+ * This is simply a placeholder because the original checkpointed RDD is expected to be
+ * fully cached. Only if an executor fails or if the user explicitly unpersists the original
+ * RDD will Spark ever attempt to compute this CheckpointRDD. When this happens, however,
+ * we must provide an informative error message.
+ *
+ * @param sc the active SparkContext
+ * @param rddId the ID of the checkpointed RDD
+ * @param numPartitions the number of partitions in the checkpointed RDD
+ */
+private[spark] class LocalCheckpointRDD[T: ClassTag](
+ @transient sc: SparkContext,
+ rddId: Int,
+ numPartitions: Int)
+ extends CheckpointRDD[T](sc) {
+
+ def this(rdd: RDD[T]) {
+ this(rdd.context, rdd.id, rdd.partitions.size)
+ }
+
+ protected override def getPartitions: Array[Partition] = {
+ (0 until numPartitions).toArray.map { i => new CheckpointRDDPartition(i) }
+ }
+
+ /**
+ * Throw an exception indicating that the relevant block is not found.
+ *
+ * This should only be called if the original RDD is explicitly unpersisted or if an
+ * executor is lost. Under normal circumstances, however, the original RDD (our child)
+ * is expected to be fully cached and so all partitions should already be computed and
+ * available in the block storage.
+ */
+ override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
+ throw new SparkException(
+ s"Checkpoint block ${RDDBlockId(rddId, partition.index)} not found! Either the executor " +
+ s"that originally checkpointed this partition is no longer alive, or the original RDD is " +
+ s"unpersisted. If this problem persists, you may consider using `rdd.checkpoint()` " +
+ s"instead, which is slower than local checkpointing but more fault-tolerant.")
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala
new file mode 100644
index 0000000000000..d6fad896845f6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Logging, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+/**
+ * An implementation of checkpointing implemented on top of Spark's caching layer.
+ *
+ * Local checkpointing trades off fault tolerance for performance by skipping the expensive
+ * step of saving the RDD data to a reliable and fault-tolerant storage. Instead, the data
+ * is written to the local, ephemeral block storage that lives in each executor. This is useful
+ * for use cases where RDDs build up long lineages that need to be truncated often (e.g. GraphX).
+ */
+private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
+ extends RDDCheckpointData[T](rdd) with Logging {
+
+ /**
+ * Ensure the RDD is fully cached so the partitions can be recovered later.
+ */
+ protected override def doCheckpoint(): CheckpointRDD[T] = {
+ val level = rdd.getStorageLevel
+
+ // Assume storage level uses disk; otherwise memory eviction may cause data loss
+ assume(level.useDisk, s"Storage level $level is not appropriate for local checkpointing")
+
+ // Not all actions compute all partitions of the RDD (e.g. take). For correctness, we
+ // must cache any missing partitions. TODO: avoid running another job here (SPARK-8582).
+ val action = (tc: TaskContext, iterator: Iterator[T]) => Utils.getIteratorSize(iterator)
+ val missingPartitionIndices = rdd.partitions.map(_.index).filter { i =>
+ !SparkEnv.get.blockManager.master.contains(RDDBlockId(rdd.id, i))
+ }
+ if (missingPartitionIndices.nonEmpty) {
+ rdd.sparkContext.runJob(rdd, action, missingPartitionIndices)
+ }
+
+ new LocalCheckpointRDD[T](rdd)
+ }
+
+}
+
+private[spark] object LocalRDDCheckpointData {
+
+ val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK
+
+ /**
+ * Transform the specified storage level to one that uses disk.
+ *
+ * This guarantees that the RDD can be recomputed multiple times correctly as long as
+ * executors do not fail. Otherwise, if the RDD is cached in memory only, for instance,
+ * the checkpoint data will be lost if the relevant block is evicted from memory.
+ *
+ * This method is idempotent.
+ */
+ def transformStorageLevel(level: StorageLevel): StorageLevel = {
+ // If this RDD is to be cached off-heap, fail fast since we cannot provide any
+ // correctness guarantees about subsequent computations after the first one
+ if (level.useOffHeap) {
+ throw new SparkException("Local checkpointing is not compatible with off-heap caching.")
+ }
+
+ StorageLevel(useDisk = true, level.useMemory, level.deserialized, level.replication)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index a838aac6e8d1a..4312d3a417759 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -21,6 +21,9 @@ import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
+/**
+ * An RDD that applies the provided function to every partition of the parent RDD.
+ */
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
new file mode 100644
index 0000000000000..b475bd8d79f85
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, Partitioner, TaskContext}
+
+/**
+ * An RDD that applies a user provided function to every partition of the parent RDD, and
+ * additionally allows the user to prepare each partition before computing the parent partition.
+ */
+private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag](
+ prev: RDD[T],
+ preparePartition: () => M,
+ executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false)
+ extends RDD[U](prev) {
+
+ override val partitioner: Option[Partitioner] = {
+ if (preservesPartitioning) firstParent[T].partitioner else None
+ }
+
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+ /**
+ * Prepare a partition before computing it from its parent.
+ */
+ override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
+ val preparedArgument = preparePartition()
+ val parentIterator = firstParent[T].iterator(partition, context)
+ executePartition(context, partition.index, preparedArgument, parentIterator)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index f83a051f5da11..6a9c004d65cff 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -33,7 +33,7 @@ import org.apache.spark._
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.storage.StorageLevel
@@ -186,7 +186,7 @@ class NewHadoopRDD[K, V](
}
} catch {
case e: Exception => {
- if (!Utils.inShutdown()) {
+ if (!ShutdownHookManager.inShutdown()) {
logWarning("Exception in RecordReader.close()", e)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 326fafb230a40..5e89cbd1eaefb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -45,7 +45,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.collection.{ExternalSorter, ExternalList, SizeTrackingCompactBuffer, CompactBuffer}
import org.apache.spark.util.random.StratifiedSamplingUtils
/**
@@ -463,12 +463,26 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
// groupByKey shouldn't use map side combine because map side combine does not
// reduce the amount of data shuffled and requires all map side data be inserted
// into a hash table, leading to more objects in the old gen.
- val createCombiner = (v: V) => CompactBuffer(v)
- val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v
- val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2
- val bufs = combineByKey[CompactBuffer[V]](
- createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false)
- bufs.asInstanceOf[RDD[(K, Iterable[V])]]
+ val createCombiner = (v: V) => ExternalList(v)
+ val mergeValue = (buf: ExternalList[V], v: V) => buf += v
+ val mergeCombiners = (c1: ExternalList[V], c2: ExternalList[V]) => {
+ c2.foreach(c => c1 += c)
+ c1
+ }
+ val aggregator = new Aggregator[K, V, ExternalList[V]](createCombiner,
+ mergeValue, mergeCombiners)
+ val shuffledRdd = if (self.partitioner != partitioner) {
+ self.partitionBy(partitioner)
+ } else {
+ self
+ }
+ def groupOnPartition(iterator: Iterator[(K, V)]): Iterator[(K, Iterable[V])] = {
+ val sorter = new ExternalSorter[K, V, ExternalList[V]](aggregator = Some(aggregator))
+ sorter.insertAll(iterator)
+ sorter.iterator.map(keyAndGroup => (keyAndGroup._1, keyAndGroup._2.asInstanceOf[Iterable[V]]))
+ }
+
+ shuffledRdd.mapPartitions(groupOnPartition(_), preservesPartitioning = true)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6d61d227382d7..081c721f23687 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -149,23 +149,43 @@ abstract class RDD[T: ClassTag](
}
/**
- * Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. This can only be used to assign a new storage level if the RDD does not
- * have a storage level set yet..
+ * Mark this RDD for persisting using the specified level.
+ *
+ * @param newLevel the target storage level
+ * @param allowOverride whether to override any existing level with the new one
*/
- def persist(newLevel: StorageLevel): this.type = {
+ private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = {
// TODO: Handle changes of StorageLevel
- if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
+ if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) {
throw new UnsupportedOperationException(
"Cannot change storage level of an RDD after it was already assigned a level")
}
- sc.persistRDD(this)
- // Register the RDD with the ContextCleaner for automatic GC-based cleanup
- sc.cleaner.foreach(_.registerRDDForCleanup(this))
+ // If this is the first time this RDD is marked for persisting, register it
+ // with the SparkContext for cleanups and accounting. Do this only once.
+ if (storageLevel == StorageLevel.NONE) {
+ sc.cleaner.foreach(_.registerRDDForCleanup(this))
+ sc.persistRDD(this)
+ }
storageLevel = newLevel
this
}
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet. Local checkpointing is an exception.
+ */
+ def persist(newLevel: StorageLevel): this.type = {
+ if (isLocallyCheckpointed) {
+ // This means the user previously called localCheckpoint(), which should have already
+ // marked this RDD for persisting. Here we should override the old storage level with
+ // one that is explicitly requested by the user (after adapting it to use disk).
+ persist(LocalRDDCheckpointData.transformStorageLevel(newLevel), allowOverride = true)
+ } else {
+ persist(newLevel, allowOverride = false)
+ }
+ }
+
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def persist(): this.type = persist(StorageLevel.MEMORY_ONLY)
@@ -1448,33 +1468,99 @@ abstract class RDD[T: ClassTag](
/**
* Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
- * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * directory set with `SparkContext#setCheckpointDir` and all references to its parent
* RDDs will be removed. This function must be called before any job has been
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
- def checkpoint(): Unit = {
+ def checkpoint(): Unit = RDDCheckpointData.synchronized {
+ // NOTE: we use a global lock here due to complexities downstream with ensuring
+ // children RDD partitions point to the correct parent partitions. In the future
+ // we should revisit this consideration.
if (context.checkpointDir.isEmpty) {
throw new SparkException("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
- // NOTE: we use a global lock here due to complexities downstream with ensuring
- // children RDD partitions point to the correct parent partitions. In the future
- // we should revisit this consideration.
- RDDCheckpointData.synchronized {
- checkpointData = Some(new RDDCheckpointData(this))
- }
+ checkpointData = Some(new ReliableRDDCheckpointData(this))
+ }
+ }
+
+ /**
+ * Mark this RDD for local checkpointing using Spark's existing caching layer.
+ *
+ * This method is for users who wish to truncate RDD lineages while skipping the expensive
+ * step of replicating the materialized data in a reliable distributed file system. This is
+ * useful for RDDs with long lineages that need to be truncated periodically (e.g. GraphX).
+ *
+ * Local checkpointing sacrifices fault-tolerance for performance. In particular, checkpointed
+ * data is written to ephemeral local storage in the executors instead of to a reliable,
+ * fault-tolerant storage. The effect is that if an executor fails during the computation,
+ * the checkpointed data may no longer be accessible, causing an irrecoverable job failure.
+ *
+ * This is NOT safe to use with dynamic allocation, which removes executors along
+ * with their cached blocks. If you must use both features, you are advised to set
+ * `spark.dynamicAllocation.cachedExecutorIdleTimeout` to a high value.
+ *
+ * The checkpoint directory set through `SparkContext#setCheckpointDir` is not used.
+ */
+ def localCheckpoint(): this.type = RDDCheckpointData.synchronized {
+ if (conf.getBoolean("spark.dynamicAllocation.enabled", false) &&
+ conf.contains("spark.dynamicAllocation.cachedExecutorIdleTimeout")) {
+ logWarning("Local checkpointing is NOT safe to use with dynamic allocation, " +
+ "which removes executors along with their cached blocks. If you must use both " +
+ "features, you are advised to set `spark.dynamicAllocation.cachedExecutorIdleTimeout` " +
+ "to a high value. E.g. If you plan to use the RDD for 1 hour, set the timeout to " +
+ "at least 1 hour.")
+ }
+
+ // Note: At this point we do not actually know whether the user will call persist() on
+ // this RDD later, so we must explicitly call it here ourselves to ensure the cached
+ // blocks are registered for cleanup later in the SparkContext.
+ //
+  // If, however, the user has already called persist() on this RDD, then we must adapt
+  // the storage level they specified to one that is appropriate for local checkpointing
+  // (i.e. uses disk) to guarantee correctness.
+
+ if (storageLevel == StorageLevel.NONE) {
+ persist(LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+ } else {
+ persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true)
}
+
+ checkpointData match {
+ case Some(reliable: ReliableRDDCheckpointData[_]) => logWarning(
+ "RDD was already marked for reliable checkpointing: overriding with local checkpoint.")
+ case _ =>
+ }
+ checkpointData = Some(new LocalRDDCheckpointData(this))
+ this
}
/**
- * Return whether this RDD has been checkpointed or not
+ * Return whether this RDD is marked for checkpointing, either reliably or locally.
*/
def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
/**
- * Gets the name of the file to which this RDD was checkpointed
+ * Return whether this RDD is marked for local checkpointing.
+ * Exposed for testing.
*/
- def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
+ private[rdd] def isLocallyCheckpointed: Boolean = {
+ checkpointData match {
+ case Some(_: LocalRDDCheckpointData[T]) => true
+ case _ => false
+ }
+ }
+
+ /**
+ * Gets the name of the directory to which this RDD was checkpointed.
+ * This is not defined if the RDD is checkpointed locally.
+ */
+ def getCheckpointFile: Option[String] = {
+ checkpointData match {
+ case Some(reliable: ReliableRDDCheckpointData[T]) => reliable.getCheckpointDir
+ case _ => None
+ }
+ }
// =======================================================================
// Other internal methods and fields
@@ -1545,7 +1631,7 @@ abstract class RDD[T: ClassTag](
if (!doCheckpointCalled) {
doCheckpointCalled = true
if (checkpointData.isDefined) {
- checkpointData.get.doCheckpoint()
+ checkpointData.get.checkpoint()
} else {
dependencies.foreach(_.rdd.doCheckpoint())
}
@@ -1557,7 +1643,7 @@ abstract class RDD[T: ClassTag](
* Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
* created from the checkpoint file, and forget its old dependencies and partitions.
*/
- private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
+ private[spark] def markCheckpointed(): Unit = {
clearDependencies()
partitions_ = null
deps = null // Forget the constructor argument for dependencies too
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 4f954363bed8e..0e43520870c0a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -19,10 +19,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark._
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.Partition
/**
* Enumeration to manage state transitions of an RDD through checkpointing
@@ -39,39 +36,31 @@ private[spark] object CheckpointState extends Enumeration {
* as well as, manages the post-checkpoint state by providing the updated partitions,
* iterator and preferred locations of the checkpointed RDD.
*/
-private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
- extends Logging with Serializable {
+private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
+ extends Serializable {
import CheckpointState._
// The checkpoint state of the associated RDD.
- private var cpState = Initialized
-
- // The file to which the associated RDD has been checkpointed to
- private var cpFile: Option[String] = None
+ protected var cpState = Initialized
- // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
- // This is defined if and only if `cpState` is `Checkpointed`.
+ // The RDD that contains our checkpointed data
private var cpRDD: Option[CheckpointRDD[T]] = None
// TODO: are we sure we need to use a global lock in the following methods?
- // Is the RDD already checkpointed
+ /**
+ * Return whether the checkpoint data for this RDD is already persisted.
+ */
def isCheckpointed: Boolean = RDDCheckpointData.synchronized {
cpState == Checkpointed
}
- // Get the file to which this RDD was checkpointed to as an Option
- def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized {
- cpFile
- }
-
/**
- * Materialize this RDD and write its content to a reliable DFS.
+ * Materialize this RDD and persist its content.
* This is called immediately after the first action invoked on this RDD has completed.
*/
- def doCheckpoint(): Unit = {
-
+ final def checkpoint(): Unit = {
// Guard against multiple threads checkpointing the same RDD by
// atomically flipping the state of this RDDCheckpointData
RDDCheckpointData.synchronized {
@@ -82,64 +71,41 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}
- // Create the output path for the checkpoint
- val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
- val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
- if (!fs.mkdirs(path)) {
- throw new SparkException(s"Failed to create checkpoint path $path")
- }
-
- // Save to file, and reload it as an RDD
- val broadcastedConf = rdd.context.broadcast(
- new SerializableConfiguration(rdd.context.hadoopConfiguration))
- val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
- if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
- rdd.context.cleaner.foreach { cleaner =>
- cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
- }
- }
-
- // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
- rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
- if (newRDD.partitions.length != rdd.partitions.length) {
- throw new SparkException(
- "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " +
- "number of partitions than original RDD " + rdd + "(" + rdd.partitions.length + ")")
- }
+ val newRDD = doCheckpoint()
- // Change the dependencies and partitions of the RDD
+ // Update our state and truncate the RDD lineage
RDDCheckpointData.synchronized {
- cpFile = Some(path.toString)
cpRDD = Some(newRDD)
- rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
+ rdd.markCheckpointed()
}
- logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}")
- }
-
- def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
- cpRDD.get.partitions
}
- def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized {
- cpRDD
- }
-}
+ /**
+ * Materialize this RDD and persist its content.
+ *
+ * Subclasses should override this method to define custom checkpointing behavior.
+ * @return the checkpoint RDD created in the process.
+ */
+ protected def doCheckpoint(): CheckpointRDD[T]
-private[spark] object RDDCheckpointData {
+ /**
+ * Return the RDD that contains our checkpointed data.
+ * This is only defined if the checkpoint state is `Checkpointed`.
+ */
+ def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized { cpRDD }
- /** Return the path of the directory to which this RDD's checkpoint data is written. */
- def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = {
- sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") }
+ /**
+ * Return the partitions of the resulting checkpoint RDD.
+ * For tests only.
+ */
+ def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
+ cpRDD.map(_.partitions).getOrElse { Array.empty }
}
- /** Clean up the files associated with the checkpoint data for this RDD. */
- def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = {
- rddCheckpointDataPath(sc, rddId).foreach { path =>
- val fs = path.getFileSystem(sc.hadoopConfiguration)
- if (fs.exists(path)) {
- fs.delete(path, true)
- }
- }
- }
}
+
+/**
+ * Global lock for synchronizing checkpoint operations.
+ */
+private[spark] object RDDCheckpointData
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
new file mode 100644
index 0000000000000..35d8b0bfd18c5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.IOException
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+
+/**
+ * An RDD that reads from checkpoint files previously written to reliable storage.
+ */
+private[spark] class ReliableCheckpointRDD[T: ClassTag](
+ @transient sc: SparkContext,
+ val checkpointPath: String)
+ extends CheckpointRDD[T](sc) {
+
+ @transient private val hadoopConf = sc.hadoopConfiguration
+ @transient private val cpath = new Path(checkpointPath)
+ @transient private val fs = cpath.getFileSystem(hadoopConf)
+ private val broadcastedConf = sc.broadcast(new SerializableConfiguration(hadoopConf))
+
+ // Fail fast if checkpoint directory does not exist
+ require(fs.exists(cpath), s"Checkpoint directory does not exist: $checkpointPath")
+
+ /**
+ * Return the path of the checkpoint directory this RDD reads data from.
+ */
+ override def getCheckpointFile: Option[String] = Some(checkpointPath)
+
+ /**
+ * Return partitions described by the files in the checkpoint directory.
+ *
+ * Since the original RDD may belong to a prior application, there is no way to know a
+ * priori the number of partitions to expect. This method assumes that the original set of
+ * checkpoint files are fully preserved in a reliable storage across application lifespans.
+ */
+ protected override def getPartitions: Array[Partition] = {
+ // listStatus can throw exception if path does not exist.
+ val inputFiles = fs.listStatus(cpath)
+ .map(_.getPath)
+ .filter(_.getName.startsWith("part-"))
+ .sortBy(_.toString)
+ // Fail fast if input files are invalid
+ inputFiles.zipWithIndex.foreach { case (path, i) =>
+ if (!path.toString.endsWith(ReliableCheckpointRDD.checkpointFileName(i))) {
+ throw new SparkException(s"Invalid checkpoint file: $path")
+ }
+ }
+ Array.tabulate(inputFiles.length)(i => new CheckpointRDDPartition(i))
+ }
+
+ /**
+ * Return the locations of the checkpoint file associated with the given partition.
+ */
+ protected override def getPreferredLocations(split: Partition): Seq[String] = {
+ val status = fs.getFileStatus(
+ new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index)))
+ val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+ locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+ }
+
+ /**
+ * Read the content of the checkpoint file associated with the given partition.
+ */
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ val file = new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index))
+ ReliableCheckpointRDD.readCheckpointFile(file, broadcastedConf, context)
+ }
+
+}
+
+private[spark] object ReliableCheckpointRDD extends Logging {
+
+ /**
+ * Return the checkpoint file name for the given partition.
+ */
+ private def checkpointFileName(partitionIndex: Int): String = {
+ "part-%05d".format(partitionIndex)
+ }
+
+ /**
+ * Write this partition's values to a checkpoint file.
+ */
+ def writeCheckpointFile[T: ClassTag](
+ path: String,
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
+ val env = SparkEnv.get
+ val outputDir = new Path(path)
+ val fs = outputDir.getFileSystem(broadcastedConf.value.value)
+
+ val finalOutputName = ReliableCheckpointRDD.checkpointFileName(ctx.partitionId())
+ val finalOutputPath = new Path(outputDir, finalOutputName)
+ val tempOutputPath =
+ new Path(outputDir, s".$finalOutputName-attempt-${ctx.attemptNumber()}")
+
+ if (fs.exists(tempOutputPath)) {
+ throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists")
+ }
+ val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
+
+ val fileOutputStream = if (blockSize < 0) {
+ fs.create(tempOutputPath, false, bufferSize)
+ } else {
+ // This is mainly for testing purpose
+ fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+ }
+ val serializer = env.serializer.newInstance()
+ val serializeStream = serializer.serializeStream(fileOutputStream)
+ Utils.tryWithSafeFinally {
+ serializeStream.writeAll(iterator)
+ } {
+ serializeStream.close()
+ }
+
+ if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ if (!fs.exists(finalOutputPath)) {
+ logInfo(s"Deleting tempOutputPath $tempOutputPath")
+ fs.delete(tempOutputPath, false)
+ throw new IOException("Checkpoint failed: failed to save output of task: " +
+ s"${ctx.attemptNumber()} and final output path does not exist: $finalOutputPath")
+ } else {
+ // Some other copy of this task must've finished before us and renamed it
+ logInfo(s"Final output path $finalOutputPath already exists; not overwriting it")
+ fs.delete(tempOutputPath, false)
+ }
+ }
+ }
+
+ /**
+ * Read the content of the specified checkpoint file.
+ */
+ def readCheckpointFile[T](
+ path: Path,
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ context: TaskContext): Iterator[T] = {
+ val env = SparkEnv.get
+ val fs = path.getFileSystem(broadcastedConf.value.value)
+ val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
+ val fileInputStream = fs.open(path, bufferSize)
+ val serializer = env.serializer.newInstance()
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addTaskCompletionListener(context => deserializeStream.close())
+
+ deserializeStream.asIterator.asInstanceOf[Iterator[T]]
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
new file mode 100644
index 0000000000000..1df8eef5ff2b9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark._
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * An implementation of checkpointing that writes the RDD data to reliable storage.
+ * This allows drivers to be restarted on failure with previously computed state.
+ */
+private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
+ extends RDDCheckpointData[T](rdd) with Logging {
+
+  // The directory to which the associated RDD has been checkpointed
+ // This is assumed to be a non-local path that points to some reliable storage
+ private val cpDir: String =
+ ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id)
+ .map(_.toString)
+ .getOrElse { throw new SparkException("Checkpoint dir must be specified.") }
+
+ /**
+ * Return the directory to which this RDD was checkpointed.
+ * If the RDD is not checkpointed yet, return None.
+ */
+ def getCheckpointDir: Option[String] = RDDCheckpointData.synchronized {
+ if (isCheckpointed) {
+ Some(cpDir.toString)
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Materialize this RDD and write its content to a reliable DFS.
+ * This is called immediately after the first action invoked on this RDD has completed.
+ */
+ protected override def doCheckpoint(): CheckpointRDD[T] = {
+
+ // Create the output path for the checkpoint
+ val path = new Path(cpDir)
+ val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
+ if (!fs.mkdirs(path)) {
+ throw new SparkException(s"Failed to create checkpoint path $cpDir")
+ }
+
+ // Save to file, and reload it as an RDD
+ val broadcastedConf = rdd.context.broadcast(
+ new SerializableConfiguration(rdd.context.hadoopConfiguration))
+ // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
+ rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _)
+ val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir)
+ if (newRDD.partitions.length != rdd.partitions.length) {
+ throw new SparkException(
+ s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
+ s"number of partitions from original RDD $rdd(${rdd.partitions.length})")
+ }
+
+ // Optionally clean our checkpoint files if the reference is out of scope
+ if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
+ rdd.context.cleaner.foreach { cleaner =>
+ cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
+ }
+ }
+
+ logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}")
+
+ newRDD
+ }
+
+}
+
+private[spark] object ReliableRDDCheckpointData {
+
+ /** Return the path of the directory to which this RDD's checkpoint data is written. */
+ def checkpointPath(sc: SparkContext, rddId: Int): Option[Path] = {
+ sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") }
+ }
+
+ /** Clean up the files associated with the checkpoint data for this RDD. */
+ def cleanCheckpoint(sc: SparkContext, rddId: Int): Unit = {
+ checkpointPath(sc, rddId).foreach { path =>
+ val fs = path.getFileSystem(sc.hadoopConfiguration)
+ if (fs.exists(path)) {
+ fs.delete(path, true)
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
index 35e44cb59c1be..fa3fecc80cb63 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala
@@ -26,16 +26,14 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit}
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.{Partition => SparkPartition, _}
-import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils}
private[spark] class SqlNewHadoopPartition(
@@ -60,18 +58,16 @@ private[spark] class SqlNewHadoopPartition(
* and the executor side to the shared Hadoop Configuration.
*
* Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with
- * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be
- * folded into core.
+ * changes based on [[org.apache.spark.rdd.HadoopRDD]].
*/
-private[spark] class SqlNewHadoopRDD[K, V](
+private[spark] class SqlNewHadoopRDD[V: ClassTag](
@transient sc : SparkContext,
broadcastedConf: Broadcast[SerializableConfiguration],
@transient initDriverSideJobFuncOpt: Option[Job => Unit],
initLocalJobFuncOpt: Option[Job => Unit],
- inputFormatClass: Class[_ <: InputFormat[K, V]],
- keyClass: Class[K],
+ inputFormatClass: Class[_ <: InputFormat[Void, V]],
valueClass: Class[V])
- extends RDD[(K, V)](sc, Nil)
+ extends RDD[V](sc, Nil)
with SparkHadoopMapReduceUtil
with Logging {
@@ -120,8 +116,8 @@ private[spark] class SqlNewHadoopRDD[K, V](
override def compute(
theSplit: SparkPartition,
- context: TaskContext): InterruptibleIterator[(K, V)] = {
- val iter = new Iterator[(K, V)] {
+ context: TaskContext): Iterator[V] = {
+ val iter = new Iterator[V] {
val split = theSplit.asInstanceOf[SqlNewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = getConf(isDriverSide = false)
@@ -154,17 +150,20 @@ private[spark] class SqlNewHadoopRDD[K, V](
configurable.setConf(conf)
case _ =>
}
- private var reader = format.createRecordReader(
+ private[this] var reader = format.createRecordReader(
split.serializableHadoopSplit.value, hadoopAttemptContext)
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => close())
- var havePair = false
- var finished = false
- var recordsSinceMetricsUpdate = 0
+
+ private[this] var havePair = false
+ private[this] var finished = false
override def hasNext: Boolean = {
+ if (context.isInterrupted) {
+ throw new TaskKilledException
+ }
if (!finished && !havePair) {
finished = !reader.nextKeyValue
if (finished) {
@@ -178,7 +177,7 @@ private[spark] class SqlNewHadoopRDD[K, V](
!finished
}
- override def next(): (K, V) = {
+ override def next(): V = {
if (!hasNext) {
throw new java.util.NoSuchElementException("End of stream")
}
@@ -186,7 +185,7 @@ private[spark] class SqlNewHadoopRDD[K, V](
if (!finished) {
inputMetrics.incRecordsRead(1)
}
- (reader.getCurrentKey, reader.getCurrentValue)
+ reader.getCurrentValue
}
private def close() {
@@ -212,23 +211,14 @@ private[spark] class SqlNewHadoopRDD[K, V](
}
}
} catch {
- case e: Exception => {
- if (!Utils.inShutdown()) {
+ case e: Exception =>
+ if (!ShutdownHookManager.inShutdown()) {
logWarning("Exception in RecordReader.close()", e)
}
- }
}
}
}
- new InterruptibleIterator(context, iter)
- }
-
- /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */
- @DeveloperApi
- def mapPartitionsWithInputSplit[U: ClassTag](
- f: (InputSplit, Iterator[(K, V)]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] = {
- new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning)
+ iter
}
override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index d2b2baef1d8c4..dfcbc51cdf616 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -47,11 +47,11 @@ private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
*
* It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence.
*
- * The lift-cycle will be:
+ * The life-cycle of an endpoint is:
*
- * constructor onStart receive* onStop
+ * constructor -> onStart -> receive* -> onStop
*
- * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use
+ * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use
* [[ThreadSafeRpcEndpoint]]
*
* If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
index 6ae47894598be..7409ac8859991 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
@@ -100,7 +100,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
val future = ask[T](message, timeout)
val result = timeout.awaitResult(future)
if (result == null) {
- throw new SparkException("Actor returned null")
+ throw new SparkException("RpcEndpoint returned null")
}
return result
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
index e0edd7d4ae968..11d123eec43ca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
@@ -24,11 +24,12 @@ import org.apache.spark.annotation.DeveloperApi
* Information about an [[org.apache.spark.Accumulable]] modified during a task or stage.
*/
@DeveloperApi
-class AccumulableInfo (
+class AccumulableInfo private[spark] (
val id: Long,
val name: String,
val update: Option[String], // represents a partial update within a task
- val value: String) {
+ val value: String,
+ val internal: Boolean) {
override def equals(other: Any): Boolean = other match {
case acc: AccumulableInfo =>
@@ -40,10 +41,10 @@ class AccumulableInfo (
object AccumulableInfo {
def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = {
- new AccumulableInfo(id, name, update, value)
+ new AccumulableInfo(id, name, update, value, internal = false)
}
def apply(id: Long, name: String, value: String): AccumulableInfo = {
- new AccumulableInfo(id, name, None, value)
+ new AccumulableInfo(id, name, None, value, internal = false)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c4fa277c21254..7ab5ccf50adb7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -200,8 +200,8 @@ class DAGScheduler(
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
- def taskSetFailed(taskSet: TaskSet, reason: String): Unit = {
- eventProcessLoop.post(TaskSetFailed(taskSet, reason))
+ def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = {
+ eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception))
}
private[scheduler]
@@ -677,8 +677,11 @@ class DAGScheduler(
submitWaitingStages()
}
- private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) {
- stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) }
+ private[scheduler] def handleTaskSetFailed(
+ taskSet: TaskSet,
+ reason: String,
+ exception: Option[Throwable]): Unit = {
+ stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) }
submitWaitingStages()
}
@@ -762,7 +765,7 @@ class DAGScheduler(
}
}
} else {
- abortStage(stage, "No active job for stage " + stage.id)
+ abortStage(stage, "No active job for stage " + stage.id, None)
}
}
@@ -773,16 +776,26 @@ class DAGScheduler(
stage.pendingTasks.clear()
// First figure out the indexes of partition ids to compute.
- val partitionsToCompute: Seq[Int] = {
+ val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = {
stage match {
case stage: ShuffleMapStage =>
- (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty)
+ val allPartitions = 0 until stage.numPartitions
+ val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty }
+ (allPartitions, filteredPartitions)
case stage: ResultStage =>
val job = stage.resultOfJob.get
- (0 until job.numPartitions).filter(id => !job.finished(id))
+ val allPartitions = 0 until job.numPartitions
+ val filteredPartitions = allPartitions.filter { id => !job.finished(id) }
+ (allPartitions, filteredPartitions)
}
}
+ // Reset internal accumulators only if this stage is not partially submitted
+ // Otherwise, we may override existing accumulator values from some tasks
+ if (allPartitions == partitionsToCompute) {
+ stage.resetInternalAccumulators()
+ }
+
val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull
runningStages += stage
@@ -806,7 +819,7 @@ class DAGScheduler(
case NonFatal(e) =>
stage.makeNewStageAttempt(partitionsToCompute.size)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
- abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+ abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
runningStages -= stage
return
}
@@ -835,13 +848,13 @@ class DAGScheduler(
} catch {
// In the case of a failure during serialization, abort the stage.
case e: NotSerializableException =>
- abortStage(stage, "Task not serializable: " + e.toString)
+ abortStage(stage, "Task not serializable: " + e.toString, Some(e))
runningStages -= stage
// Abort execution
return
case NonFatal(e) =>
- abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
+ abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e))
runningStages -= stage
return
}
@@ -852,7 +865,8 @@ class DAGScheduler(
partitionsToCompute.map { id =>
val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
- new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
+ new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
+ taskBinary, part, locs, stage.internalAccumulators)
}
case stage: ResultStage =>
@@ -861,12 +875,13 @@ class DAGScheduler(
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = taskIdToLocations(id)
- new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
+ new ResultTask(stage.id, stage.latestInfo.attemptId,
+ taskBinary, part, locs, id, stage.internalAccumulators)
}
}
} catch {
case NonFatal(e) =>
- abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+ abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
runningStages -= stage
return
}
@@ -916,9 +931,11 @@ class DAGScheduler(
// To avoid UI cruft, ignore cases where value wasn't updated
if (acc.name.isDefined && partialValue != acc.zero) {
val name = acc.name.get
- stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}")
+ val value = s"${acc.value}"
+ stage.latestInfo.accumulables(id) =
+ new AccumulableInfo(id, name, None, value, acc.isInternal)
event.taskInfo.accumulables +=
- AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}")
+ new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal)
}
}
} catch {
@@ -1084,7 +1101,8 @@ class DAGScheduler(
}
if (disallowStageRetryForTest) {
- abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config",
+ None)
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
@@ -1112,7 +1130,7 @@ class DAGScheduler(
case commitDenied: TaskCommitDenied =>
// Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
- case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
+ case exceptionFailure: ExceptionFailure =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
case TaskResultLost =>
@@ -1221,7 +1239,10 @@ class DAGScheduler(
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
- private[scheduler] def abortStage(failedStage: Stage, reason: String) {
+ private[scheduler] def abortStage(
+ failedStage: Stage,
+ reason: String,
+ exception: Option[Throwable]): Unit = {
if (!stageIdToStage.contains(failedStage.id)) {
// Skip all the actions if the stage has been removed.
return
@@ -1230,7 +1251,7 @@ class DAGScheduler(
activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
failedStage.latestInfo.completionTime = Some(clock.getTimeMillis())
for (job <- dependentJobs) {
- failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
+ failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception)
}
if (dependentJobs.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -1238,8 +1259,11 @@ class DAGScheduler(
}
/** Fails a job and all stages that are only used by that job, and cleans up relevant state. */
- private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
- val error = new SparkException(failureReason)
+ private def failJobAndIndependentStages(
+ job: ActiveJob,
+ failureReason: String,
+ exception: Option[Throwable] = None): Unit = {
+ val error = new SparkException(failureReason, exception.getOrElse(null))
var ableToCancelStages = true
val shouldInterruptThread =
@@ -1448,8 +1472,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
dagScheduler.handleTaskCompletion(completion)
- case TaskSetFailed(taskSet, reason) =>
- dagScheduler.handleTaskSetFailed(taskSet, reason)
+ case TaskSetFailed(taskSet, reason, exception) =>
+ dagScheduler.handleTaskSetFailed(taskSet, reason, exception)
case ResubmitFailedStages =>
dagScheduler.resubmitFailedStages()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index a213d419cf033..f72a52e85dc15 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -73,6 +73,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend
private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[scheduler]
-case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable])
+ extends DAGSchedulerEvent
private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
index 8321037cdc026..5d926377ce86b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -162,7 +162,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
private[spark] object OutputCommitCoordinator {
- // This actor is used only for RPC
+ // This endpoint is used only for RPC
private[spark] class OutputCommitCoordinatorEndpoint(
override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator)
extends RpcEndpoint with Logging {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 9c2606e278c54..c4dc080e2b22b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -45,8 +45,10 @@ private[spark] class ResultTask[T, U](
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
- val outputId: Int)
- extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
+ val outputId: Int,
+ internalAccumulators: Seq[Accumulator[Long]])
+ extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators)
+ with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 14c8c00961487..f478f9982afef 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -43,12 +43,14 @@ private[spark] class ShuffleMapTask(
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
- @transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
+ @transient private var locs: Seq[TaskLocation],
+ internalAccumulators: Seq[Accumulator[Long]])
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators)
+ with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, 0, null, new Partition { override def index: Int = 0 }, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
@@ -69,7 +71,7 @@ private[spark] class ShuffleMapTask(
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
- return writer.stop(success = true).get
+ writer.stop(success = true).get
} catch {
case e: Exception =>
try {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 40a333a3e06b2..1cf06856ffbc2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -68,6 +68,22 @@ private[spark] abstract class Stage(
val name = callSite.shortForm
val details = callSite.longForm
+ private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty
+
+ /** Internal accumulators shared across all tasks in this stage. */
+ def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators
+
+ /**
+ * Re-initialize the internal accumulators associated with this stage.
+ *
+ * This is called every time the stage is submitted, *except* when a subset of tasks
+ * belonging to this stage has already finished. Otherwise, reinitializing the internal
+ * accumulators here again will override partial values from the finished tasks.
+ */
+ def resetInternalAccumulators(): Unit = {
+ _internalAccumulators = InternalAccumulator.create(rdd.sparkContext)
+ }
+
/**
* Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized
* here, before any attempts have actually been created, because the DAGScheduler uses this
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1978305cfefbd..9edf9f048f9fd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
+import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.unsafe.memory.TaskMemoryManager
@@ -47,7 +47,8 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
- var partitionId: Int) extends Serializable {
+ val partitionId: Int,
+ internalAccumulators: Seq[Accumulator[Long]]) extends Serializable {
/**
* The key of the Map is the accumulator id and the value of the Map is the latest accumulator
@@ -68,12 +69,13 @@ private[spark] abstract class Task[T](
metricsSystem: MetricsSystem)
: (T, AccumulatorUpdates) = {
context = new TaskContextImpl(
- stageId = stageId,
- partitionId = partitionId,
- taskAttemptId = taskAttemptId,
- attemptNumber = attemptNumber,
- taskMemoryManager = taskMemoryManager,
- metricsSystem = metricsSystem,
+ stageId,
+ partitionId,
+ taskAttemptId,
+ attemptNumber,
+ taskMemoryManager,
+ metricsSystem,
+ internalAccumulators,
runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 82455b0426a5d..818b95d67f6be 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -662,7 +662,7 @@ private[spark] class TaskSetManager(
val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " +
reason.asInstanceOf[TaskFailedReason].toErrorString
- reason match {
+ val failureException: Option[Throwable] = reason match {
case fetchFailed: FetchFailed =>
logWarning(failureReason)
if (!successful(index)) {
@@ -671,6 +671,7 @@ private[spark] class TaskSetManager(
}
// Not adding to failed executors for FetchFailed.
isZombie = true
+ None
case ef: ExceptionFailure =>
taskMetrics = ef.metrics.orNull
@@ -706,12 +707,15 @@ private[spark] class TaskSetManager(
s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " +
s"${ef.className} (${ef.description}) [duplicate $dupCount]")
}
+ ef.exception
case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others
logWarning(failureReason)
+ None
case e: TaskEndReason =>
logError("Unknown TaskEndReason: " + e)
+ None
}
// always add to failed executors
failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()).
@@ -728,16 +732,16 @@ private[spark] class TaskSetManager(
logError("Task %d in stage %s failed %d times; aborting job".format(
index, taskSet.id, maxTaskFailures))
abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:"
- .format(index, taskSet.id, maxTaskFailures, failureReason))
+ .format(index, taskSet.id, maxTaskFailures, failureReason), failureException)
return
}
}
maybeFinishTaskSet()
}
- def abort(message: String): Unit = sched.synchronized {
+ def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized {
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.dagScheduler.taskSetFailed(taskSet, message)
+ sched.dagScheduler.taskSetFailed(taskSet, message, exception)
isZombie = true
maybeFinishTaskSet()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index bd89160af4ffa..5730a87f960a0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -134,7 +134,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
} else {
logInfo("Registered executor: " + executorRef + " with ID " + executorId)
- context.reply(RegisteredExecutor)
addressToExecutorId(executorRef.address) = executorId
totalCoreCount.addAndGet(cores)
totalRegisteredExecutors.addAndGet(1)
@@ -149,6 +148,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
}
}
+ // Note: some tests expect the reply to come after we put the executor in the map
+ context.reply(RegisteredExecutor)
listenerBus.post(
SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
makeOffers()
@@ -421,21 +422,24 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
logWarning(s"Executor to kill $id does not exist!")
}
+ // If an executor is already pending to be removed, do not kill it again (SPARK-9795)
+ val executorsToKill = knownExecutors.filter { id => !executorsPendingToRemove.contains(id) }
+ executorsPendingToRemove ++= executorsToKill
+
// If we do not wish to replace the executors we kill, sync the target number of executors
// with the cluster manager to avoid allocating new ones. When computing the new target,
// take into account executors that are pending to be added or removed.
if (!replace) {
- doRequestTotalExecutors(numExistingExecutors + numPendingExecutors
- - executorsPendingToRemove.size - knownExecutors.size)
+ doRequestTotalExecutors(
+ numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)
}
- executorsPendingToRemove ++= knownExecutors
- doKillExecutors(knownExecutors)
+ doKillExecutors(executorsToKill)
}
/**
* Kill the given list of executors through the cluster manager.
- * Return whether the kill request is acknowledged.
+ * @return whether the kill request is acknowledged.
*/
protected def doKillExecutors(executorIds: Seq[String]): Boolean = false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
index 26e72c0bff38d..626a2b7d69abe 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
@@ -22,7 +22,7 @@ import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress}
/**
* Grouping of data for an executor used by CoarseGrainedSchedulerBackend.
*
- * @param executorEndpoint The ActorRef representing this executor
+ * @param executorEndpoint The RpcEndpointRef representing this executor
* @param executorAddress The network address of this executor
* @param executorHost The hostname that this executor is running on
* @param freeCores The current number of cores available for work on the executor
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 687ae9620460f..bbe51b4a09a22 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -152,6 +152,34 @@ private[spark] class SparkDeploySchedulerBackend(
super.applicationId
}
+ /**
+ * Request executors from the Master by specifying the total number desired,
+ * including existing pending and running executors.
+ *
+ * @return whether the request is acknowledged.
+ */
+ protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
+ Option(client) match {
+ case Some(c) => c.requestTotalExecutors(requestedTotal)
+ case None =>
+ logWarning("Attempted to request executors before driver fully initialized.")
+ false
+ }
+ }
+
+ /**
+ * Kill the given list of executors through the Master.
+ * @return whether the kill request is acknowledged.
+ */
+ protected override def doKillExecutors(executorIds: Seq[String]): Boolean = {
+ Option(client) match {
+ case Some(c) => c.killExecutors(executorIds)
+ case None =>
+ logWarning("Attempted to kill executors before driver fully initialized.")
+ false
+ }
+ }
+
private def waitForRegistration() = {
registrationBarrier.acquire()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index b7fde0d9b3265..d6e1e9e5bebc2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -26,12 +26,15 @@ import scala.collection.mutable.{HashMap, HashSet}
import com.google.common.collect.HashBiMap
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
-import org.apache.mesos.{Scheduler => MScheduler, _}
+import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver}
+
+import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState}
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient
import org.apache.spark.rpc.RpcAddress
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.util.Utils
-import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState}
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@@ -46,7 +49,8 @@ import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState}
private[spark] class CoarseMesosSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
- master: String)
+ master: String,
+ securityManager: SecurityManager)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
with MScheduler
with MesosSchedulerUtils {
@@ -56,12 +60,19 @@ private[spark] class CoarseMesosSchedulerBackend(
// Maximum number of cores to acquire (TODO: we'll need more flexible controls here)
val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt
+ // If shuffle service is enabled, the Spark driver will register with the shuffle service.
+ // This is for cleaning up shuffle files reliably.
+ private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
+
// Cores we have acquired with each Mesos task ID
val coresByTaskId = new HashMap[Int, Int]
var totalCoresAcquired = 0
val slaveIdsWithExecutors = new HashSet[String]
+ // Mapping from slave Id to hostname
+ private val slaveIdToHost = new HashMap[String, String]
+
val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String]
// How many times tasks on each slave failed
val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int]
@@ -90,6 +101,19 @@ private[spark] class CoarseMesosSchedulerBackend(
private val slaveOfferConstraints =
parseConstraintString(sc.conf.get("spark.mesos.constraints", ""))
+ // A client for talking to the external shuffle service, if it is enabled
+ private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = {
+ if (shuffleServiceEnabled) {
+ Some(new MesosExternalShuffleClient(
+ SparkTransportConf.fromSparkConf(conf),
+ securityManager,
+ securityManager.isAuthenticationEnabled(),
+ securityManager.isSaslEncryptionEnabled()))
+ } else {
+ None
+ }
+ }
+
var nextMesosTaskId = 0
@volatile var appId: String = _
@@ -170,6 +194,11 @@ private[spark] class CoarseMesosSchedulerBackend(
s" --app-id $appId")
command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get))
}
+
+ conf.getOption("spark.mesos.uris").map { uris =>
+ setupUris(uris, command)
+ }
+
command.build()
}
@@ -188,6 +217,7 @@ private[spark] class CoarseMesosSchedulerBackend(
override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
appId = frameworkId.getValue
+ mesosExternalShuffleClient.foreach(_.init(appId))
logInfo("Registered as framework ID " + appId)
markRegistered()
}
@@ -244,6 +274,7 @@ private[spark] class CoarseMesosSchedulerBackend(
// accept the offer and launch the task
logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
+ slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname
d.launchTasks(
Collections.singleton(offer.getId),
Collections.singleton(taskBuilder.build()), filters)
@@ -261,7 +292,27 @@ private[spark] class CoarseMesosSchedulerBackend(
val taskId = status.getTaskId.getValue.toInt
val state = status.getState
logInfo(s"Mesos task $taskId is now $state")
+ val slaveId: String = status.getSlaveId.getValue
stateLock.synchronized {
+ // If the shuffle service is enabled, have the driver register with each one of the
+ // shuffle services. This allows the shuffle services to clean up state associated with
+ // this application when the driver exits. There is currently not a great way to detect
+ // this through Mesos, since the shuffle services are set up independently.
+ if (TaskState.fromMesos(state).equals(TaskState.RUNNING) &&
+ slaveIdToHost.contains(slaveId) &&
+ shuffleServiceEnabled) {
+ assume(mesosExternalShuffleClient.isDefined,
+ "External shuffle client was not instantiated even though shuffle service is enabled.")
+ // TODO: Remove this and allow the MesosExternalShuffleService to detect
+ // framework termination when new Mesos Framework HTTP API is available.
+ val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337)
+ val hostname = slaveIdToHost.remove(slaveId).get
+ logDebug(s"Connecting to shuffle service on slave $slaveId, " +
+ s"host $hostname, port $externalShufflePort for app ${conf.getAppId}")
+ mesosExternalShuffleClient.get
+ .registerDriverWithShuffleService(hostname, externalShufflePort)
+ }
+
if (TaskState.isFinished(TaskState.fromMesos(state))) {
val slaveId = taskIdToSlaveId(taskId)
slaveIdsWithExecutors -= slaveId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index f078547e71352..64ec2b8e3db15 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -403,6 +403,9 @@ private[spark] class MesosClusterScheduler(
}
builder.setValue(s"$executable $cmdOptions $jar $appArguments")
builder.setEnvironment(envBuilder.build())
+ conf.getOption("spark.mesos.uris").map { uris =>
+ setupUris(uris, builder)
+ }
builder.build()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 3f63ec1c5832f..5c20606d58715 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -133,6 +133,11 @@ private[spark] class MesosSchedulerBackend(
builder.addAllResources(usedCpuResources)
builder.addAllResources(usedMemResources)
+
+ sc.conf.getOption("spark.mesos.uris").map { uris =>
+ setupUris(uris, command)
+ }
+
val executorInfo = builder
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
index c04920e4f5873..5b854aa5c2754 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
@@ -331,4 +331,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
sc.executorMemory
}
+ def setupUris(uris: String, builder: CommandInfo.Builder): Unit = {
+ uris.split(",").foreach { uri =>
+ builder.addUris(CommandInfo.URI.newBuilder().setValue(uri.trim()))
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 0ff7562e912ca..2eab6aff045eb 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -38,7 +38,7 @@ import org.apache.spark.network.util.ByteUnit
import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.collection.{ExternalList, ExternalListSerializer, CompactBuffer}
/**
* A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
@@ -103,6 +103,7 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer())
kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
+ kryo.register(classOf[ExternalList[_]], new ExternalListSerializer[Any]())
kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas))
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index f038b722957b8..8c3a72644c38a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -19,6 +19,9 @@ package org.apache.spark.shuffle
import scala.collection.mutable
+import com.google.common.annotations.VisibleForTesting
+
+import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext}
/**
@@ -34,11 +37,19 @@ import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext}
* set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
* this set changes. This is all done by synchronizing access on "this" to mutate state and using
* wait() and notifyAll() to signal changes.
+ *
+ * Use `ShuffleMemoryManager.create()` factory method to create a new instance.
+ *
+ * @param maxMemory total amount of memory available for execution, in bytes.
+ * @param pageSizeBytes number of bytes for each page, by default.
*/
-private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
- private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes
+private[spark]
+class ShuffleMemoryManager protected (
+ val maxMemory: Long,
+ val pageSizeBytes: Long)
+ extends Logging {
- def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))
+ private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes
private def currentTaskAttemptId(): Long = {
// In case this is called on the driver, return an invalid task attempt id.
@@ -85,7 +96,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
return toGrant
} else {
logInfo(
- s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
+ s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
wait()
}
} else {
@@ -116,17 +127,57 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
taskMemory.remove(taskAttemptId)
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
+
+ /** Returns the memory consumption, in bytes, for the current task */
+ def getMemoryConsumptionForThisTask(): Long = synchronized {
+ val taskAttemptId = currentTaskAttemptId()
+ taskMemory.getOrElse(taskAttemptId, 0L)
+ }
}
-private object ShuffleMemoryManager {
+
+private[spark] object ShuffleMemoryManager {
+
+ def create(conf: SparkConf, numCores: Int): ShuffleMemoryManager = {
+ val maxMemory = ShuffleMemoryManager.getMaxMemory(conf)
+ val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores)
+ new ShuffleMemoryManager(maxMemory, pageSize)
+ }
+
+ def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = {
+ new ShuffleMemoryManager(maxMemory, pageSizeBytes)
+ }
+
+ @VisibleForTesting
+ def createForTesting(maxMemory: Long): ShuffleMemoryManager = {
+ new ShuffleMemoryManager(maxMemory, 4 * 1024 * 1024)
+ }
+
/**
* Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
* of the memory pool and a safety factor since collections can sometimes grow bigger than
* the size we target before we estimate their sizes again.
*/
- def getMaxMemory(conf: SparkConf): Long = {
+ private def getMaxMemory(conf: SparkConf): Long = {
val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
(Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}
+
+ /**
+ * Returns the page size, in bytes.
+ *
+ * If the user didn't explicitly set "spark.buffer.pageSize", we figure out the default value
+ * by looking at the number of cores available to the process and the total amount of memory,
+ * and then divide it by a safety factor.
+ */
+ private def getPageSize(conf: SparkConf, maxMemory: Long, numCores: Int): Long = {
+ val minPageSize = 1L * 1024 * 1024 // 1MB
+ val maxPageSize = 64L * minPageSize // 64MB
+ val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
+ val safetyFactor = 8
+ val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor)
+ val default = math.min(maxPageSize, math.max(minPageSize, size))
+ conf.getSizeAsBytes("spark.buffer.pageSize", default)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index de79fa56f017b..0c8f08f0f3b1b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,7 +17,7 @@
package org.apache.spark.shuffle.hash
-import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
@@ -100,8 +100,10 @@ private[spark] class HashShuffleReader[K, C](
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
- context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
- context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
sorter.iterator
case None =>
aggregatedIter
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 5f537692a16c5..3f8d26e1d4cab 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -22,7 +22,7 @@ import java.io.{IOException, File}
import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ShutdownHookManager, Utils}
/**
* Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -133,7 +133,6 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
Utils.getConfiguredLocalDirs(conf).flatMap { rootDir =>
try {
val localDir = Utils.createDirectory(rootDir, "blockmgr")
- Utils.chmod700(localDir)
logInfo(s"Created local directory at $localDir")
Some(localDir)
} catch {
@@ -145,7 +144,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
}
private def addShutdownHook(): AnyRef = {
- Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () =>
+ ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () =>
logInfo("Shutdown hook called")
DiskBlockManager.this.doStop()
}
@@ -155,7 +154,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
private[spark] def stop() {
// Remove the shutdown hook. It causes memory leaks if we leave it around.
try {
- Utils.removeShutdownHook(shutdownHook)
+ ShutdownHookManager.removeShutdownHook(shutdownHook)
} catch {
case e: Exception =>
logError(s"Exception while removing shutdown hook.", e)
@@ -169,7 +168,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
localDirs.foreach { localDir =>
if (localDir.isDirectory() && localDir.exists()) {
try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(localDir)) {
+ Utils.deleteRecursively(localDir)
+ }
} catch {
case e: Exception =>
logError(s"Exception while deleting local spark dir: $localDir", e)
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
index b53c86e89a273..22878783fca67 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
@@ -27,11 +27,12 @@ import scala.util.control.NonFatal
import com.google.common.io.ByteStreams
import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile}
+import tachyon.conf.TachyonConf
import tachyon.TachyonURI
-import org.apache.spark.{SparkException, SparkConf, Logging}
+import org.apache.spark.Logging
import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ShutdownHookManager, Utils}
/**
@@ -60,7 +61,11 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log
rootDirs = s"$storeDir/$appFolderName/$executorId"
master = blockManager.conf.get(ExternalBlockStore.MASTER_URL, "tachyon://localhost:19998")
- client = if (master != null && master != "") TachyonFS.get(new TachyonURI(master)) else null
+ client = if (master != null && master != "") {
+ TachyonFS.get(new TachyonURI(master), new TachyonConf())
+ } else {
+ null
+ }
// original implementation call System.exit, we change it to run without extblkstore support
if (client == null) {
logError("Failed to connect to the Tachyon as the master address is not configured")
@@ -75,7 +80,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log
// in order to avoid having really large inodes at the top level in Tachyon.
tachyonDirs = createTachyonDirs()
subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir))
- tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir))
+ tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir))
}
override def toString: String = {"ExternalBlockStore-Tachyon"}
@@ -235,7 +240,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log
logDebug("Shutdown hook called")
tachyonDirs.foreach { tachyonDir =>
try {
- if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) {
+ if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) {
Utils.deleteRecursively(tachyonDir, client)
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index c8356467fab87..779c0ba083596 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -106,7 +106,11 @@ private[spark] object JettyUtils extends Logging {
path: String,
servlet: HttpServlet,
basePath: String): ServletContextHandler = {
- val prefixedPath = attachPrefix(basePath, path)
+ val prefixedPath = if (basePath == "" && path == "/") {
+ path
+ } else {
+ (basePath + path).stripSuffix("/")
+ }
val contextHandler = new ServletContextHandler
val holder = new ServletHolder(servlet)
contextHandler.setContextPath(prefixedPath)
@@ -121,7 +125,7 @@ private[spark] object JettyUtils extends Logging {
beforeRedirect: HttpServletRequest => Unit = x => (),
basePath: String = "",
httpMethods: Set[String] = Set("GET")): ServletContextHandler = {
- val prefixedDestPath = attachPrefix(basePath, destPath)
+ val prefixedDestPath = basePath + destPath
val servlet = new HttpServlet {
override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = {
if (httpMethods.contains("GET")) {
@@ -246,11 +250,6 @@ private[spark] object JettyUtils extends Logging {
val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
ServerInfo(server, boundPort, collection)
}
-
- /** Attach a prefix to the given path, but avoid returning an empty path */
- private def attachPrefix(basePath: String, relativePath: String): String = {
- if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/")
- }
}
private[spark] case class ServerInfo(
diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
index 17d7b39c2d951..6e2375477a688 100644
--- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
@@ -159,9 +159,9 @@ private[ui] trait PagedTable[T] {
// "goButtonJsFuncName"
val formJs =
s"""$$(function(){
- | $$( "#form-task-page" ).submit(function(event) {
- | var page = $$("#form-task-page-no").val()
- | var pageSize = $$("#form-task-page-size").val()
+ | $$( "#form-$tableId-page" ).submit(function(event) {
+ | var page = $$("#form-$tableId-page-no").val()
+ | var pageSize = $$("#form-$tableId-page-size").val()
| pageSize = pageSize ? pageSize: 100;
| if (page != "") {
| ${goButtonJsFuncName}(page, pageSize);
@@ -173,12 +173,14 @@ private[ui] trait PagedTable[T] {
-
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 3788916cf39bb..d8b90568b7b9a 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -64,11 +64,11 @@ private[spark] class SparkUI private (
attachTab(new EnvironmentTab(this))
attachTab(new ExecutorsTab(this))
attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static"))
- attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath))
+ attachHandler(createRedirectHandler("/", "/jobs/", basePath = basePath))
attachHandler(ApiRootResource.getServletHandler(this))
// This should be POST only, but, the YARN AM proxy won't proxy POSTs
attachHandler(createRedirectHandler(
- "/stages/stage/kill", "/stages", stagesTab.handleKillRequest,
+ "/stages/stage/kill", "/stages/", stagesTab.handleKillRequest,
httpMethods = Set("GET", "POST")))
}
initialize()
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index e2d25e36365fa..cb122eaed83d1 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -62,6 +62,13 @@ private[spark] object ToolTips {
"""Time that the executor spent paused for Java garbage collection while the task was
running."""
+ val PEAK_EXECUTION_MEMORY =
+ """Execution memory refers to the memory used by internal data structures created during
+ shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator
+ should be approximately the sum of the peak sizes across all such data structures created
+ in this task. For SQL jobs, this only tracks all unsafe operators, broadcast joins, and
+ external sort."""
+
val JOB_TIMELINE =
"""Shows when jobs started and ended and when executors joined or left. Drag to scroll.
Click Enable Zooming and use mouse wheel to zoom in/out."""
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 718aea7e1dc22..f2da417724104 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -352,7 +352,8 @@ private[spark] object UIUtils extends Logging {
*/
private def showDagViz(graphs: Seq[RDDOperationGraph], forJob: Boolean): Seq[Node] = {
+: getFormattedSizeQuantiles(peakExecutionMemory)
+ }
+
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
@@ -359,10 +392,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay
val schedulerDelayQuantiles = schedulerDelayTitle +:
getFormattedTimeQuantiles(schedulerDelays)
-
- def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] =
- getDistributionQuantiles(data).map(d =>
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/IdGenerator.scala b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala
index 17e55f7996bf7..53934ad4ce477 100644
--- a/core/src/main/scala/org/apache/spark/util/IdGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala
@@ -22,10 +22,10 @@ import java.util.concurrent.atomic.AtomicInteger
/**
* A util used to get a unique generation ID. This is a wrapper around Java's
* AtomicInteger. An example usage is in BlockManager, where each BlockManager
- * instance would start an Akka actor and we use this utility to assign the Akka
- * actors unique names.
+ * instance would start an RpcEndpoint and we use this utility to assign the RpcEndpoints'
+ * unique names.
*/
private[spark] class IdGenerator {
- private var id = new AtomicInteger
+ private val id = new AtomicInteger
def next: Int = id.incrementAndGet
}
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index c600319d9ddb4..cbc94fd6d54d9 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -790,7 +790,7 @@ private[spark] object JsonProtocol {
val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace").
map(_.extract[String]).orNull
val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson)
- ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics)
+ ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None)
case `taskResultLost` => TaskResultLost
case `taskKilled` => TaskKilled
case `executorLostFailure` =>
diff --git a/core/src/main/scala/org/apache/spark/util/ManualClock.scala b/core/src/main/scala/org/apache/spark/util/ManualClock.scala
index 1718554061985..e7a65d74a440e 100644
--- a/core/src/main/scala/org/apache/spark/util/ManualClock.scala
+++ b/core/src/main/scala/org/apache/spark/util/ManualClock.scala
@@ -58,7 +58,7 @@ private[spark] class ManualClock(private var time: Long) extends Clock {
*/
def waitTillTime(targetTime: Long): Long = synchronized {
while (time < targetTime) {
- wait(100)
+ wait(10)
}
getTimeMillis()
}
diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
new file mode 100644
index 0000000000000..61ff9b89ec1c1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.File
+import java.util.PriorityQueue
+
+import scala.util.{Failure, Success, Try}
+import tachyon.client.TachyonFile
+
+import org.apache.hadoop.fs.FileSystem
+import org.apache.spark.Logging
+
+/**
+ * Utility methods for managing shutdown hooks and directories to be deleted on JVM exit.
+ */
+private[spark] object ShutdownHookManager extends Logging {
+ val DEFAULT_SHUTDOWN_PRIORITY = 100
+
+ /**
+ * The shutdown priority of the SparkContext instance. This is lower than the default
+ * priority, so that by default hooks are run before the context is shut down.
+ */
+ val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50
+
+ /**
+ * The shutdown priority of temp directory must be lower than the SparkContext shutdown
+ * priority. Otherwise cleaning the temp directories while Spark jobs are running can
+ * throw undesirable errors at the time of shutdown.
+ */
+ val TEMP_DIR_SHUTDOWN_PRIORITY = 25
+
+ private lazy val shutdownHooks = {
+ val manager = new SparkShutdownHookManager()
+ manager.install()
+ manager
+ }
+
+ private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
+ private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]()
+
+ // Add a shutdown hook to delete the temp dirs when the JVM exits
+ addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () =>
+ logInfo("Shutdown hook called")
+ shutdownDeletePaths.foreach { dirPath =>
+ try {
+ logInfo("Deleting directory " + dirPath)
+ Utils.deleteRecursively(new File(dirPath))
+ } catch {
+ case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e)
+ }
+ }
+ }
+
+ // Register the path to be deleted via shutdown hook
+ def registerShutdownDeleteDir(file: File) {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths += absolutePath
+ }
+ }
+
+ // Register the tachyon path to be deleted via shutdown hook
+ def registerShutdownDeleteDir(tachyonfile: TachyonFile) {
+ val absolutePath = tachyonfile.getPath()
+ shutdownDeleteTachyonPaths.synchronized {
+ shutdownDeleteTachyonPaths += absolutePath
+ }
+ }
+
+ // Remove the path to be deleted via shutdown hook
+ def removeShutdownDeleteDir(file: File) {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.remove(absolutePath)
+ }
+ }
+
+ // Remove the tachyon path to be deleted via shutdown hook
+ def removeShutdownDeleteDir(tachyonfile: TachyonFile) {
+ val absolutePath = tachyonfile.getPath()
+ shutdownDeleteTachyonPaths.synchronized {
+ shutdownDeleteTachyonPaths.remove(absolutePath)
+ }
+ }
+
+ // Is the path already registered to be deleted via a shutdown hook ?
+ def hasShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.contains(absolutePath)
+ }
+ }
+
+ // Is the path already registered to be deleted via a shutdown hook ?
+ def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = {
+ val absolutePath = file.getPath()
+ shutdownDeleteTachyonPaths.synchronized {
+ shutdownDeleteTachyonPaths.contains(absolutePath)
+ }
+ }
+
+ // Note: if file is child of some registered path, while not equal to it, then return true;
+ // else false. This is to ensure that two shutdown hooks do not try to delete each others
+ // paths - resulting in IOException and incomplete cleanup.
+ def hasRootAsShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ val retval = shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.exists { path =>
+ !absolutePath.equals(path) && absolutePath.startsWith(path)
+ }
+ }
+ if (retval) {
+ logInfo("path = " + file + ", already present as root for deletion.")
+ }
+ retval
+ }
+
+ // Note: if file is child of some registered path, while not equal to it, then return true;
+ // else false. This is to ensure that two shutdown hooks do not try to delete each others
+ // paths - resulting in Exception and incomplete cleanup.
+ def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = {
+ val absolutePath = file.getPath()
+ val retval = shutdownDeleteTachyonPaths.synchronized {
+ shutdownDeleteTachyonPaths.exists { path =>
+ !absolutePath.equals(path) && absolutePath.startsWith(path)
+ }
+ }
+ if (retval) {
+ logInfo("path = " + file + ", already present as root for deletion.")
+ }
+ retval
+ }
+
+ /**
+ * Detect whether this thread might be executing a shutdown hook. Will always return true if
+ * the current thread is running a shutdown hook but may spuriously return true otherwise (e.g.
+ * if System.exit was just called by a concurrent thread).
+ *
+ * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing
+ * an IllegalStateException.
+ */
+ def inShutdown(): Boolean = {
+ try {
+ val hook = new Thread {
+ override def run() {}
+ }
+ Runtime.getRuntime.addShutdownHook(hook)
+ Runtime.getRuntime.removeShutdownHook(hook)
+ } catch {
+ case ise: IllegalStateException => return true
+ }
+ false
+ }
+
+ /**
+ * Adds a shutdown hook with default priority.
+ *
+ * @param hook The code to run during shutdown.
+ * @return A handle that can be used to unregister the shutdown hook.
+ */
+ def addShutdownHook(hook: () => Unit): AnyRef = {
+ addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook)
+ }
+
+ /**
+ * Adds a shutdown hook with the given priority. Hooks with lower priority values run
+ * first.
+ *
+ * @param hook The code to run during shutdown.
+ * @return A handle that can be used to unregister the shutdown hook.
+ */
+ def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = {
+ shutdownHooks.add(priority, hook)
+ }
+
+ /**
+ * Remove a previously installed shutdown hook.
+ *
+ * @param ref A handle returned by `addShutdownHook`.
+ * @return Whether the hook was removed.
+ */
+ def removeShutdownHook(ref: AnyRef): Boolean = {
+ shutdownHooks.remove(ref)
+ }
+
+}
+
+private [util] class SparkShutdownHookManager {
+
+ private val hooks = new PriorityQueue[SparkShutdownHook]()
+ private var shuttingDown = false
+
+ /**
+ * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not
+ * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for
+ * the best.
+ */
+ def install(): Unit = {
+ val hookTask = new Runnable() {
+ override def run(): Unit = runAll()
+ }
+ Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match {
+ case Success(shmClass) =>
+ val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get()
+ .asInstanceOf[Int]
+ val shm = shmClass.getMethod("get").invoke(null)
+ shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int])
+ .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30))
+
+ case Failure(_) =>
+ Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook"));
+ }
+ }
+
+ def runAll(): Unit = synchronized {
+ shuttingDown = true
+ while (!hooks.isEmpty()) {
+ Try(Utils.logUncaughtExceptions(hooks.poll().run()))
+ }
+ }
+
+ def add(priority: Int, hook: () => Unit): AnyRef = synchronized {
+ checkState()
+ val hookRef = new SparkShutdownHook(priority, hook)
+ hooks.add(hookRef)
+ hookRef
+ }
+
+ def remove(ref: AnyRef): Boolean = synchronized {
+ hooks.remove(ref)
+ }
+
+ private def checkState(): Unit = {
+ if (shuttingDown) {
+ throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.")
+ }
+ }
+
+}
+
+private class SparkShutdownHook(private val priority: Int, hook: () => Unit)
+ extends Comparable[SparkShutdownHook] {
+
+ override def compareTo(other: SparkShutdownHook): Int = {
+ other.priority - priority
+ }
+
+ def run(): Unit = hook()
+
+}
diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
index ad3db1fbb57ed..7248187247330 100644
--- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
+++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
@@ -33,7 +33,7 @@ private[spark] object SparkUncaughtExceptionHandler
// We may have been called from a shutdown hook. If so, we must not call System.exit().
// (If we do, we will deadlock.)
- if (!Utils.inShutdown()) {
+ if (!ShutdownHookManager.inShutdown()) {
if (exception.isInstanceOf[OutOfMemoryError]) {
System.exit(SparkExitCode.OOM)
} else {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index c4012d0e83f7d..f2abf227dc129 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,7 +21,7 @@ import java.io._
import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
-import java.util.{PriorityQueue, Properties, Locale, Random, UUID}
+import java.util.{Properties, Locale, Random, UUID}
import java.util.concurrent._
import javax.net.ssl.HttpsURLConnection
@@ -65,21 +65,6 @@ private[spark] object CallSite {
private[spark] object Utils extends Logging {
val random = new Random()
- val DEFAULT_SHUTDOWN_PRIORITY = 100
-
- /**
- * The shutdown priority of the SparkContext instance. This is lower than the default
- * priority, so that by default hooks are run before the context is shut down.
- */
- val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50
-
- /**
- * The shutdown priority of temp directory must be lower than the SparkContext shutdown
- * priority. Otherwise cleaning the temp directories while Spark jobs are running can
- * throw undesirable errors at the time of shutdown.
- */
- val TEMP_DIR_SHUTDOWN_PRIORITY = 25
-
/**
* Define a default value for driver memory here since this value is referenced across the code
* base and nearly all files already use Utils.scala
@@ -90,9 +75,6 @@ private[spark] object Utils extends Logging {
@volatile private var localRootDirs: Array[String] = null
- private val shutdownHooks = new SparkShutdownHookManager()
- shutdownHooks.install()
-
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -205,86 +187,6 @@ private[spark] object Utils extends Logging {
}
}
- private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
- private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]()
-
- // Add a shutdown hook to delete the temp dirs when the JVM exits
- addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () =>
- logInfo("Shutdown hook called")
- shutdownDeletePaths.foreach { dirPath =>
- try {
- logInfo("Deleting directory " + dirPath)
- Utils.deleteRecursively(new File(dirPath))
- } catch {
- case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e)
- }
- }
- }
-
- // Register the path to be deleted via shutdown hook
- def registerShutdownDeleteDir(file: File) {
- val absolutePath = file.getAbsolutePath()
- shutdownDeletePaths.synchronized {
- shutdownDeletePaths += absolutePath
- }
- }
-
- // Register the tachyon path to be deleted via shutdown hook
- def registerShutdownDeleteDir(tachyonfile: TachyonFile) {
- val absolutePath = tachyonfile.getPath()
- shutdownDeleteTachyonPaths.synchronized {
- shutdownDeleteTachyonPaths += absolutePath
- }
- }
-
- // Is the path already registered to be deleted via a shutdown hook ?
- def hasShutdownDeleteDir(file: File): Boolean = {
- val absolutePath = file.getAbsolutePath()
- shutdownDeletePaths.synchronized {
- shutdownDeletePaths.contains(absolutePath)
- }
- }
-
- // Is the path already registered to be deleted via a shutdown hook ?
- def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = {
- val absolutePath = file.getPath()
- shutdownDeleteTachyonPaths.synchronized {
- shutdownDeleteTachyonPaths.contains(absolutePath)
- }
- }
-
- // Note: if file is child of some registered path, while not equal to it, then return true;
- // else false. This is to ensure that two shutdown hooks do not try to delete each others
- // paths - resulting in IOException and incomplete cleanup.
- def hasRootAsShutdownDeleteDir(file: File): Boolean = {
- val absolutePath = file.getAbsolutePath()
- val retval = shutdownDeletePaths.synchronized {
- shutdownDeletePaths.exists { path =>
- !absolutePath.equals(path) && absolutePath.startsWith(path)
- }
- }
- if (retval) {
- logInfo("path = " + file + ", already present as root for deletion.")
- }
- retval
- }
-
- // Note: if file is child of some registered path, while not equal to it, then return true;
- // else false. This is to ensure that two shutdown hooks do not try to delete each others
- // paths - resulting in Exception and incomplete cleanup.
- def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = {
- val absolutePath = file.getPath()
- val retval = shutdownDeleteTachyonPaths.synchronized {
- shutdownDeleteTachyonPaths.exists { path =>
- !absolutePath.equals(path) && absolutePath.startsWith(path)
- }
- }
- if (retval) {
- logInfo("path = " + file + ", already present as root for deletion.")
- }
- retval
- }
-
/**
* JDK equivalent of `chmod 700 file`.
*
@@ -333,7 +235,7 @@ private[spark] object Utils extends Logging {
root: String = System.getProperty("java.io.tmpdir"),
namePrefix: String = "spark"): File = {
val dir = createDirectory(root, namePrefix)
- registerShutdownDeleteDir(dir)
+ ShutdownHookManager.registerShutdownDeleteDir(dir)
dir
}
@@ -973,9 +875,7 @@ private[spark] object Utils extends Logging {
if (savedIOException != null) {
throw savedIOException
}
- shutdownDeletePaths.synchronized {
- shutdownDeletePaths.remove(file.getAbsolutePath)
- }
+ ShutdownHookManager.removeShutdownDeleteDir(file)
}
} finally {
if (!file.delete()) {
@@ -1478,27 +1378,6 @@ private[spark] object Utils extends Logging {
serializer.deserialize[T](serializer.serialize(value))
}
- /**
- * Detect whether this thread might be executing a shutdown hook. Will always return true if
- * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g.
- * if System.exit was just called by a concurrent thread).
- *
- * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing
- * an IllegalStateException.
- */
- def inShutdown(): Boolean = {
- try {
- val hook = new Thread {
- override def run() {}
- }
- Runtime.getRuntime.addShutdownHook(hook)
- Runtime.getRuntime.removeShutdownHook(hook)
- } catch {
- case ise: IllegalStateException => return true
- }
- false
- }
-
private def isSpace(c: Char): Boolean = {
" \t\r\n".indexOf(c) != -1
}
@@ -2221,37 +2100,6 @@ private[spark] object Utils extends Logging {
msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX)
}
- /**
- * Adds a shutdown hook with default priority.
- *
- * @param hook The code to run during shutdown.
- * @return A handle that can be used to unregister the shutdown hook.
- */
- def addShutdownHook(hook: () => Unit): AnyRef = {
- addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook)
- }
-
- /**
- * Adds a shutdown hook with the given priority. Hooks with lower priority values run
- * first.
- *
- * @param hook The code to run during shutdown.
- * @return A handle that can be used to unregister the shutdown hook.
- */
- def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = {
- shutdownHooks.add(priority, hook)
- }
-
- /**
- * Remove a previously installed shutdown hook.
- *
- * @param ref A handle returned by `addShutdownHook`.
- * @return Whether the hook was removed.
- */
- def removeShutdownHook(ref: AnyRef): Boolean = {
- shutdownHooks.remove(ref)
- }
-
/**
* To avoid calling `Utils.getCallSite` for every single RDD we create in the body,
* set a dummy call site that RDDs use instead. This is for performance optimization.
@@ -2286,70 +2134,17 @@ private[spark] object Utils extends Logging {
isInDirectory(parent, child.getParentFile)
}
-}
-
-private [util] class SparkShutdownHookManager {
-
- private val hooks = new PriorityQueue[SparkShutdownHook]()
- private var shuttingDown = false
-
/**
- * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not
- * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for
- * the best.
+ * Return whether dynamic allocation is enabled in the given conf
+ * Dynamic allocation and explicitly setting the number of executors are inherently
+ * incompatible. In environments where dynamic allocation is turned on by default,
+ * the latter should override the former (SPARK-9092).
*/
- def install(): Unit = {
- val hookTask = new Runnable() {
- override def run(): Unit = runAll()
- }
- Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match {
- case Success(shmClass) =>
- val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get()
- .asInstanceOf[Int]
- val shm = shmClass.getMethod("get").invoke(null)
- shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int])
- .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30))
-
- case Failure(_) =>
- Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook"));
- }
+ def isDynamicAllocationEnabled(conf: SparkConf): Boolean = {
+ conf.contains("spark.dynamicAllocation.enabled") &&
+ conf.getInt("spark.executor.instances", 0) == 0
}
- def runAll(): Unit = synchronized {
- shuttingDown = true
- while (!hooks.isEmpty()) {
- Try(Utils.logUncaughtExceptions(hooks.poll().run()))
- }
- }
-
- def add(priority: Int, hook: () => Unit): AnyRef = synchronized {
- checkState()
- val hookRef = new SparkShutdownHook(priority, hook)
- hooks.add(hookRef)
- hookRef
- }
-
- def remove(ref: AnyRef): Boolean = synchronized {
- hooks.remove(ref)
- }
-
- private def checkState(): Unit = {
- if (shuttingDown) {
- throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.")
- }
- }
-
-}
-
-private class SparkShutdownHook(private val priority: Int, hook: () => Unit)
- extends Comparable[SparkShutdownHook] {
-
- override def compareTo(other: SparkShutdownHook): Int = {
- other.priority - priority
- }
-
- def run(): Unit = hook()
-
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala b/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala
new file mode 100644
index 0000000000000..e0fb9e131de33
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util.cleanup
+
+import java.lang.ref.{ReferenceQueue, WeakReference}
+
+/**
+ * Classes that represent cleaning tasks.
+ */
+private[spark] sealed trait CleanupTask
+private[spark] case class CleanRDD(rddId: Int) extends CleanupTask
+private[spark] case class CleanShuffle(shuffleId: Int) extends CleanupTask
+private[spark] case class CleanBroadcast(broadcastId: Long) extends CleanupTask
+private[spark] case class CleanAccum(accId: Long) extends CleanupTask
+private[spark] case class CleanCheckpoint(rddId: Int) extends CleanupTask
+private[spark] case class CleanExternalList(pathsToClean: Iterable[String]) extends CleanupTask
+
+/**
+ * A WeakReference associated with a CleanupTask.
+ *
+ * When the referent object becomes only weakly reachable, the corresponding
+ * CleanupTaskWeakReference is automatically added to the given reference queue.
+ */
+private[spark] class CleanupTaskWeakReference(
+ val task: CleanupTask,
+ referent: AnyRef,
+ referenceQueue: ReferenceQueue[AnyRef])
+ extends WeakReference(referent, referenceQueue)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index d166037351c31..3284113809dca 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -24,14 +24,11 @@ import scala.collection.BufferedIterator
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import com.google.common.io.ByteStreams
-
-import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.serializer.{DeserializationStream, Serializer}
-import org.apache.spark.storage.{BlockId, BlockManager}
+import org.apache.spark.storage.{DiskBlockObjectWriter, BlockId, BlockManager}
import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
-import org.apache.spark.executor.ShuffleWriteMetrics
/**
* :: DeveloperApi ::
@@ -69,36 +66,16 @@ class ExternalAppendOnlyMap[K, V, C](
extends Iterable[(K, C)]
with Serializable
with Logging
- with Spillable[SizeTracker] {
+ with SpillableCollection[(K, C), SizeTrackingAppendOnlyMap[K, C]] {
private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
- private val sparkConf = SparkEnv.get.conf
- private val diskBlockManager = blockManager.diskBlockManager
-
- /**
- * Size of object batches when reading/writing from serializers.
- *
- * Objects are written in batches, with each batch using its own serialization stream. This
- * cuts down on the size of reference-tracking maps constructed when deserializing a stream.
- *
- * NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
- * grow internal data structures by growing + copying every time the number of objects doubles.
- */
- private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
-
- // Number of bytes spilled in total
- private var _diskBytesSpilled = 0L
-
- // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
- private val fileBufferSize =
- sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
- // Write metrics for current spill
- private var curWriteMetrics: ShuffleWriteMetrics = _
+ // Peak size of the in-memory map observed so far, in bytes
+ private var _peakMemoryUsedBytes: Long = 0L
+ def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
private val keyComparator = new HashComparator[K]
- private val ser = serializer.newInstance()
/**
* Insert the given key and value into the map.
@@ -126,7 +103,11 @@ class ExternalAppendOnlyMap[K, V, C](
while (entries.hasNext) {
curEntry = entries.next()
- if (maybeSpill(currentMap, currentMap.estimateSize())) {
+ val estimatedSize = currentMap.estimateSize()
+ if (estimatedSize > _peakMemoryUsedBytes) {
+ _peakMemoryUsedBytes = estimatedSize
+ }
+ if (maybeSpill(currentMap, estimatedSize)) {
currentMap = new SizeTrackingAppendOnlyMap[K, C]
}
currentMap.changeValue(curEntry._1, update)
@@ -147,68 +128,6 @@ class ExternalAppendOnlyMap[K, V, C](
insertAll(entries.iterator)
}
- /**
- * Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
- */
- override protected[this] def spill(collection: SizeTracker): Unit = {
- val (blockId, file) = diskBlockManager.createTempLocalBlock()
- curWriteMetrics = new ShuffleWriteMetrics()
- var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
- var objectsWritten = 0
-
- // List of batch sizes (bytes) in the order they are written to disk
- val batchSizes = new ArrayBuffer[Long]
-
- // Flush the disk writer's contents to disk, and update relevant variables
- def flush(): Unit = {
- val w = writer
- writer = null
- w.commitAndClose()
- _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
- batchSizes.append(curWriteMetrics.shuffleBytesWritten)
- objectsWritten = 0
- }
-
- var success = false
- try {
- val it = currentMap.destructiveSortedIterator(keyComparator)
- while (it.hasNext) {
- val kv = it.next()
- writer.write(kv._1, kv._2)
- objectsWritten += 1
-
- if (objectsWritten == serializerBatchSize) {
- flush()
- curWriteMetrics = new ShuffleWriteMetrics()
- writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
- }
- }
- if (objectsWritten > 0) {
- flush()
- } else if (writer != null) {
- val w = writer
- writer = null
- w.revertPartialWritesAndClose()
- }
- success = true
- } finally {
- if (!success) {
- // This code path only happens if an exception was thrown above before we set success;
- // close our stuff and let the exception be thrown further
- if (writer != null) {
- writer.revertPartialWritesAndClose()
- }
- if (file.exists()) {
- file.delete()
- }
- }
- }
-
- spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
- }
-
- def diskBytesSpilled: Long = _diskBytesSpilled
-
/**
* Return an iterator that merges the in-memory map with the spilled maps.
* If no spill has occurred, simply return the in-memory map's iterator.
@@ -374,128 +293,38 @@ class ExternalAppendOnlyMap[K, V, C](
* An iterator that returns (K, C) pairs in sorted order from an on-disk map
*/
private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
- extends Iterator[(K, C)]
+ extends DiskIterator(file, blockId, batchSizes)
{
- private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1
- assert(file.length() == batchOffsets.last,
- "File length is not equal to the last batch offset:\n" +
- s" file length = ${file.length}\n" +
- s" last batch offset = ${batchOffsets.last}\n" +
- s" all batch offsets = ${batchOffsets.mkString(",")}"
- )
-
- private var batchIndex = 0 // Which batch we're in
- private var fileStream: FileInputStream = null
-
- // An intermediate stream that reads from exactly one batch
- // This guards against pre-fetching and other arbitrary behavior of higher level streams
- private var deserializeStream = nextBatchStream()
- private var nextItem: (K, C) = null
- private var objectsRead = 0
-
- /**
- * Construct a stream that reads only from the next batch.
- */
- private def nextBatchStream(): DeserializationStream = {
- // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
- // we're still in a valid batch.
- if (batchIndex < batchOffsets.length - 1) {
- if (deserializeStream != null) {
- deserializeStream.close()
- fileStream.close()
- deserializeStream = null
- fileStream = null
- }
-
- val start = batchOffsets(batchIndex)
- fileStream = new FileInputStream(file)
- fileStream.getChannel.position(start)
- batchIndex += 1
-
- val end = batchOffsets(batchIndex)
-
- assert(end >= start, "start = " + start + ", end = " + end +
- ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
-
- val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
- val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
- ser.deserializeStream(compressedStream)
- } else {
- // No more batches left
- cleanup()
- null
- }
+ override protected def readNextItemFromStream(
+ deserializeStream: DeserializationStream): (K, C) = {
+ val k = deserializeStream.readKey().asInstanceOf[K]
+ val v = deserializeStream.readValue().asInstanceOf[C]
+ (k, v)
}
- /**
- * Return the next (K, C) pair from the deserialization stream.
- *
- * If the current batch is drained, construct a stream for the next batch and read from it.
- * If no more pairs are left, return null.
- */
- private def readNextItem(): (K, C) = {
- try {
- val k = deserializeStream.readKey().asInstanceOf[K]
- val c = deserializeStream.readValue().asInstanceOf[C]
- val item = (k, c)
- objectsRead += 1
- if (objectsRead == serializerBatchSize) {
- objectsRead = 0
- deserializeStream = nextBatchStream()
- }
- item
- } catch {
- case e: EOFException =>
- cleanup()
- null
- }
- }
+ override protected def shouldCleanupFileAfterOneIteration(): Boolean = true
+ }
- override def hasNext: Boolean = {
- if (nextItem == null) {
- if (deserializeStream == null) {
- return false
- }
- nextItem = readNextItem()
- }
- nextItem != null
- }
- override def next(): (K, C) = {
- val item = if (nextItem == null) readNextItem() else nextItem
- if (item == null) {
- throw new NoSuchElementException
- }
- nextItem = null
- item
- }
+ /** Convenience function to hash the given (K, C) pair by the key. */
+ private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1)
- private def cleanup() {
- batchIndex = batchOffsets.length // Prevent reading any other batch
- val ds = deserializeStream
- if (ds != null) {
- ds.close()
- deserializeStream = null
- }
- if (fileStream != null) {
- fileStream.close()
- fileStream = null
- }
- if (file.exists()) {
- file.delete()
- }
- }
+ override protected def getIteratorForCurrentSpillable(): Iterator[(K, C)] = {
+ currentMap.destructiveSortedIterator(keyComparator)
+ }
- val context = TaskContext.get()
- // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in
- // a TaskContext.
- if (context != null) {
- context.addTaskCompletionListener(context => cleanup())
- }
+ override protected def writeNextObject(
+ c: (K, C),
+ writer: DiskBlockObjectWriter): Unit = {
+ writer.write(c._1, c._2)
}
- /** Convenience function to hash the given (K, C) pair by the key. */
- private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1)
+ override protected def recordNextSpilledPart(
+ file: File,
+ blockId: BlockId,
+ batchSizes: ArrayBuffer[Long]): Unit = {
+ spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
+ }
}
private[spark] object ExternalAppendOnlyMap {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala
new file mode 100644
index 0000000000000..f0e4fcff81420
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util.collection
+
+import java.io._
+
+import org.apache.spark.util.TaskCompletionListener
+import org.apache.spark.{TaskContext, ExecutorCleaner, SparkEnv}
+
+import scala.reflect.ClassTag
+import scala.collection.generic.Growable
+import scala.collection.mutable.ArrayBuffer
+
+import com.esotericsoftware.kryo.io.{Output, Input}
+import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer}
+
+import org.apache.spark.util.collection.ExternalList._
+import org.apache.spark.serializer.DeserializationStream
+import org.apache.spark.storage.{DiskBlockObjectWriter, BlockId}
+
+
+/**
+ * List that can spill some of its contents to disk if its contents cannot be held in memory.
+ * Implementation is based heavily on `org.apache.spark.util.collection.ExternalAppendOnlyMap`.
+ */
+@SerialVersionUID(1L)
+private[spark] class ExternalList[T](implicit var tag: ClassTag[T])
+ extends Growable[T]
+ with Iterable[T]
+ with SpillableCollection[T, SizeTrackingCompactBuffer[T]]
+ with Serializable {
+
+ // Var to allow rebuilding it during Java serialization
+ private var spilledLists = new ArrayBuffer[DiskListIterable]
+ private var currentInMemoryList = new SizeTrackingCompactBuffer[T]()
+ private var numItems = 0
+
+ // We don't know up front what files will need to be cleaned up from this list.
+ // So check after the task is completed, after which this ExternalList will be
+ // completely built.
+ private var context = TaskContext.get
+ if (context != null) {
+ context.addTaskCompletionListener(new ScheduleCleanExternalList(this))
+ }
+
+ override def size(): Int = numItems
+
+ override def +=(value: T): this.type = {
+ currentInMemoryList += value
+ if (maybeSpill(currentInMemoryList, currentInMemoryList.estimateSize())) {
+ currentInMemoryList = new SizeTrackingCompactBuffer
+ }
+ numItems += 1
+ this
+ }
+
+ override def clear(): Unit = {
+ spilledLists.foreach(_.deleteBackingFile())
+ spilledLists.clear()
+ currentInMemoryList = new SizeTrackingCompactBuffer[T]()
+ }
+
+ def getBackingFileLocations(): Iterable[String] = {
+ val locations = new ArrayBuffer[String]
+ for (diskList <- spilledLists) {
+ locations.append(diskList.backingFilePath())
+ }
+ return locations
+ }
+
+ def registerForCleanup(): Unit = {
+ if (spilledLists.size > 0) {
+ executorCleaner.registerExternalListForCleanup(this)
+ }
+ }
+
+ override def iterator: Iterator[T] = {
+ val myIt = currentInMemoryList.iterator
+ val allIts = spilledLists.map(_.iterator) ++ Seq(myIt)
+ allIts.foldLeft(Iterator[T]())(_ ++ _)
+ }
+
+ private class DiskListIterable(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
+ extends Iterable[T] {
+ override def iterator: Iterator[T] = {
+ new DiskListIterator(file, blockId, batchSizes)
+ }
+ def deleteBackingFile(): Unit = {
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+ def backingFilePath(): String = file.getAbsolutePath()
+ }
+
+ private class DiskListIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
+ extends DiskIterator(file, blockId, batchSizes) {
+ override protected def readNextItemFromStream(deserializeStream: DeserializationStream): T = {
+ deserializeStream.readKey[Int]()
+ deserializeStream.readValue[T]()
+ }
+
+ // Need to be able to iterate multiple times, so don't clean up the file every time
+ override protected def shouldCleanupFileAfterOneIteration(): Boolean = false
+ }
+
+ @throws(classOf[IOException])
+ private def writeObject(stream: ObjectOutputStream): Unit = {
+ stream.writeObject(tag)
+ stream.writeInt(this.size)
+ val it = this.iterator
+ while (it.hasNext) {
+ stream.writeObject(it.next)
+ }
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(stream: ObjectInputStream): Unit = {
+ tag = stream.readObject().asInstanceOf[ClassTag[T]]
+ val listSize = stream.readInt()
+ spilledLists = new ArrayBuffer[DiskListIterable]
+ currentInMemoryList = new SizeTrackingCompactBuffer[T]
+ for(i <- 0L until listSize) {
+ val newItem = stream.readObject().asInstanceOf[T]
+ this.+=(newItem)
+ }
+    // Upon deserialization, the context might have changed. So we can't just hold a single
+    // context; we must retrieve the current context every time.
+ // Notice that in Kryo serialization this object is constructed from scratch
+ // and thus will look for the current TaskContext that way.
+ context = TaskContext.get()
+ if (context != null) {
+ context.addTaskCompletionListener(new ScheduleCleanExternalList(this))
+ }
+ }
+
+ override protected def getIteratorForCurrentSpillable(): Iterator[T] = {
+ currentInMemoryList.iterator
+ }
+
+ override protected def recordNextSpilledPart(
+ file: File,
+ blockId: BlockId,
+ batchSizes: ArrayBuffer[Long]): Unit = {
+ spilledLists += new DiskListIterable(file, blockId, batchSizes)
+ }
+ override protected def writeNextObject(c: T, writer: DiskBlockObjectWriter): Unit = {
+ writer.write(0, c)
+ }
+}
+
+/**
+ * Companion object for constants and singleton-references that we don't want to lose when
+ * Java-serializing
+ */
+private[spark] object ExternalList {
+
+ private class ScheduleCleanExternalList(private var list: ExternalList[_])
+ extends TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = {
+ if (list != null) {
+ executorCleaner.registerExternalListForCleanup(list)
+ // Release reference to allow GC to clean it up
+ list = null
+ }
+ }
+ }
+
+ def apply[T: ClassTag](): ExternalList[T] = new ExternalList[T]
+
+ def apply[T: ClassTag](value: T): ExternalList[T] = {
+ val buf = new ExternalList[T]
+ buf += value
+ buf
+ }
+
+ private val executorCleaner: ExecutorCleaner = SparkEnv.get.executorCleaner
+}
+
+private[spark] class ExternalListSerializer[T: ClassTag] extends KSerializer[ExternalList[T]] {
+ override def write(kryo: Kryo, output: Output, list: ExternalList[T]): Unit = {
+ output.writeInt(list.size)
+ val it = list.iterator
+ while (it.hasNext) {
+ kryo.writeClassAndObject(output, it.next())
+ }
+ }
+
+ override def read(kryo: Kryo, input: Input, clazz: Class[ExternalList[T]]): ExternalList[T] = {
+ val listToRead = new ExternalList[T]
+ val listSize = input.readInt()
+ for (i <- 0L until listSize) {
+ val newItem = kryo.readClassAndObject(input).asInstanceOf[T]
+ listToRead += newItem
+ }
+ listToRead
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index ba7ec834d622d..19287edbaf166 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -152,6 +152,9 @@ private[spark] class ExternalSorter[K, V, C](
private var _diskBytesSpilled = 0L
def diskBytesSpilled: Long = _diskBytesSpilled
+ // Peak size of the in-memory data structure observed so far, in bytes
+ private var _peakMemoryUsedBytes: Long = 0L
+ def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided through by the
@@ -224,15 +227,22 @@ private[spark] class ExternalSorter[K, V, C](
return
}
+ var estimatedSize = 0L
if (usingMap) {
- if (maybeSpill(map, map.estimateSize())) {
+ estimatedSize = map.estimateSize()
+ if (maybeSpill(map, estimatedSize)) {
map = new PartitionedAppendOnlyMap[K, C]
}
} else {
- if (maybeSpill(buffer, buffer.estimateSize())) {
+ estimatedSize = buffer.estimateSize()
+ if (maybeSpill(buffer, estimatedSize)) {
buffer = newBuffer()
}
}
+
+ if (estimatedSize > _peakMemoryUsedBytes) {
+ _peakMemoryUsedBytes = estimatedSize
+ }
}
/**
@@ -684,8 +694,10 @@ private[spark] class ExternalSorter[K, V, C](
}
}
- context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
- context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes)
lengths
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala
new file mode 100644
index 0000000000000..d923e9a9e0bd1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util.collection
+
+import scala.reflect.ClassTag
+
+/**
+ * CompactBuffer that keeps track of its size via SizeTracker.
+ */
+private[spark] class SizeTrackingCompactBuffer[T: ClassTag] extends CompactBuffer[T]
+ with SizeTracker {
+
+ override def +=(t: T): SizeTrackingCompactBuffer[T] = {
+ super.+=(t)
+ super.afterUpdate()
+ this
+ }
+
+ override def ++=(t: TraversableOnce[T]): SizeTrackingCompactBuffer[T] = {
+ super.++=(t)
+ super.afterUpdate()
+ this
+ }
+}
+
+private[spark] object SizeTrackingCompactBuffer {
+ def apply[T: ClassTag](): SizeTrackingCompactBuffer[T] = new SizeTrackingCompactBuffer[T]
+
+ def apply[T: ClassTag](value: T): SizeTrackingCompactBuffer[T] = {
+ val buf = new SizeTrackingCompactBuffer[T]
+ buf += value
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 747ecf075a397..a710d618f3d23 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util.collection
import org.apache.spark.Logging
import org.apache.spark.SparkEnv
+import org.apache.spark.util.collection.Spillable._
/**
* Spills contents of an in-memory collection to disk when the memory threshold
@@ -39,14 +40,6 @@ private[spark] trait Spillable[C] extends Logging {
// It's used for checking spilling frequency
protected def addElementsRead(): Unit = { _elementsRead += 1 }
- // Memory manager that can be used to acquire/release memory
- private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
-
- // Initial threshold for the size of a collection before we start tracking its memory usage
- // Exposed for testing
- private[this] val initialMemoryThreshold: Long =
- SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
-
// Threshold for this collection's size in bytes before we start tracking its memory usage
// To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
private[this] var myMemoryThreshold = initialMemoryThreshold
@@ -117,4 +110,15 @@ private[spark] trait Spillable[C] extends Logging {
.format(threadId, org.apache.spark.util.Utils.bytesToString(size),
_spillCount, if (_spillCount > 1) "s" else ""))
}
+
+}
+
+private object Spillable {
+ // Memory manager that can be used to acquire/release memory
+ protected val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
+
+ // Initial threshold for the size of a collection before we start tracking its memory usage
+ // Exposed for testing
+ protected val initialMemoryThreshold: Long =
+ SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala
new file mode 100644
index 0000000000000..c4d0f46bb0bd7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.{EOFException, BufferedInputStream, FileInputStream, File}
+
+import scala.collection.mutable.ArrayBuffer
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.serializer.{DeserializationStream, Serializer}
+import org.apache.spark.storage.{DiskBlockManager, BlockId, DiskBlockObjectWriter, BlockManager}
+import org.apache.spark.util.collection.SpillableCollection._
+
+/**
+ *
+ * Collection that can spill to disk. Takes type parameters T, the iterable type, and
+ * C, the type of the elements returned by T's iterator.
+ */
+private[spark] trait SpillableCollection[C, T <: Iterable[C]] extends Spillable[T] {
+ // Write metrics for current spill
+ private var curWriteMetrics: ShuffleWriteMetrics = _
+ // Number of bytes spilled in total
+ protected var _diskBytesSpilled = 0L
+ private lazy val ser = serializer.newInstance()
+
+ def diskBytesSpilled: Long = _diskBytesSpilled
+
+ override protected final def spill(collection: T): Unit = {
+ val (blockId, file) = diskBlockManager.createTempLocalBlock()
+ curWriteMetrics = new ShuffleWriteMetrics()
+ var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
+ var objectsWritten = 0
+
+ // List of batch sizes (bytes) in the order they are written to disk
+ val batchSizes = new ArrayBuffer[Long]
+
+ // Flush the disk writer's contents to disk, and update relevant variables
+ def flush(): Unit = {
+ val w = writer
+ writer = null
+ w.commitAndClose()
+ _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
+ batchSizes.append(curWriteMetrics.shuffleBytesWritten)
+ objectsWritten = 0
+ }
+
+ var success = false
+ try {
+ val it = getIteratorForCurrentSpillable()
+ while (it.hasNext) {
+ val kv = it.next()
+ writeNextObject(kv, writer)
+ objectsWritten += 1
+
+ if (objectsWritten == serializerBatchSize) {
+ flush()
+ curWriteMetrics = new ShuffleWriteMetrics()
+ writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
+ }
+ }
+ if (objectsWritten > 0) {
+ flush()
+ } else if (writer != null) {
+ val w = writer
+ writer = null
+ w.revertPartialWritesAndClose()
+ }
+ success = true
+ } finally {
+ if (!success) {
+ // This code path only happens if an exception was thrown above before we set success;
+ // close our stuff and let the exception be thrown further
+ if (writer != null) {
+ writer.revertPartialWritesAndClose()
+ }
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+ }
+
+ recordNextSpilledPart(file, blockId, batchSizes)
+ }
+
+
+ protected def getIteratorForCurrentSpillable(): Iterator[C]
+ protected def writeNextObject(c: C, writer: DiskBlockObjectWriter): Unit
+ protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
+
+ /**
+ * Iterator backed by elements from batches on disk.
+ */
+ protected abstract class DiskIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
+ extends Iterator[C] {
+ private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1
+ assert(file.length() == batchOffsets.last,
+ "File length is not equal to the last batch offset:\n" +
+ s" file length = ${file.length}\n" +
+ s" last batch offset = ${batchOffsets.last}\n" +
+ s" all batch offsets = ${batchOffsets.mkString(",")}"
+ )
+
+ private var batchIndex = 0 // Which batch we're in
+ private var fileStream: FileInputStream = null
+
+ // An intermediate stream that reads from exactly one batch
+ // This guards against pre-fetching and other arbitrary behavior of higher level streams
+ private var deserializeStream = nextBatchStream()
+ private var nextItem: Option[C] = None
+ private var objectsRead = 0
+
+ /**
+ * Construct a stream that reads only from the next batch.
+ */
+ protected def nextBatchStream(): DeserializationStream = {
+ // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
+ // we're still in a valid batch.
+ if (batchIndex < batchOffsets.length - 1) {
+ if (deserializeStream != null) {
+ deserializeStream.close()
+ fileStream.close()
+ deserializeStream = null
+ fileStream = null
+ }
+
+ val start = batchOffsets(batchIndex)
+ fileStream = new FileInputStream(file)
+ fileStream.getChannel.position(start)
+ batchIndex += 1
+
+ val end = batchOffsets(batchIndex)
+
+ assert(end >= start, "start = " + start + ", end = " + end +
+ ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
+
+ val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
+ val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
+ ser.deserializeStream(compressedStream)
+ } else {
+ // No more batches left
+ cleanup()
+ null
+ }
+ }
+
+ /**
+ * Return the next item from the deserialization stream.
+ *
+ * If the current batch is drained, construct a stream for the next batch and read from it.
+ * If no more items are left, return null.
+ */
+ protected def readNextItem(): Option[C] = {
+ try {
+ val item = readNextItemFromStream(deserializeStream)
+ objectsRead += 1
+ if (objectsRead == serializerBatchSize) {
+ objectsRead = 0
+ deserializeStream = nextBatchStream()
+ }
+ Some(item)
+ } catch {
+ case e: EOFException =>
+ cleanup()
+ None
+ }
+ }
+
+ private def cleanup() {
+ batchIndex = batchOffsets.length // Prevent reading any other batch
+ val ds = deserializeStream
+ deserializeStream = null
+ if (ds != null) {
+ ds.close()
+ }
+ val fs = fileStream
+ fileStream = null
+ if (fs != null) {
+ fs.close()
+ }
+ if (shouldCleanupFileAfterOneIteration()) {
+ file.delete()
+ }
+ }
+
+ override def hasNext(): Boolean = {
+ if (!nextItem.isDefined) {
+ if (deserializeStream == null) {
+ return false
+ }
+ nextItem = readNextItem()
+ }
+ nextItem.isDefined
+ }
+
+ override def next(): C = {
+ if (!hasNext()) {
+ throw new NoSuchElementException()
+ }
+ val nextValue = nextItem.get
+ nextItem = None
+ nextValue
+ }
+
+ protected def readNextItemFromStream(deserializeStream: DeserializationStream): C
+ protected def shouldCleanupFileAfterOneIteration(): Boolean
+ }
+}
+
+private object SpillableCollection {
+ private def sparkConf(): SparkConf = SparkEnv.get.conf
+ private def blockManager(): BlockManager = SparkEnv.get.blockManager
+ private def diskBlockManager(): DiskBlockManager = blockManager.diskBlockManager
+ private def fileBufferSize(): Int =
+ // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+ sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
+ /**
+ * Size of object batches when reading/writing from serializers.
+ *
+ * Objects are written in batches, with each batch using its own serialization stream. This
+ * cuts down on the size of reference-tracking maps constructed when deserializing a stream.
+ *
+ * NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
+ * grow internal data structures by growing + copying every time the number of objects doubles.
+ */
+ private def serializerBatchSize(): Long =
+ sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
+
+ private def serializer(): Serializer = SparkEnv.get.serializer
+}
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 786b97ad7b9ec..c156b03cdb7c4 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -176,10 +176,15 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
* A sampler for sampling with replacement, based on values drawn from Poisson distribution.
*
* @param fraction the sampling fraction (with replacement)
+ * @param useGapSamplingIfPossible if true, use gap sampling when sampling ratio is low.
* @tparam T item type
*/
@DeveloperApi
-class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
+class PoissonSampler[T: ClassTag](
+ fraction: Double,
+ useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {
+
+ def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true)
/** Epsilon slop to avoid failure from floating point jitter. */
require(
@@ -199,17 +204,18 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T]
override def sample(items: Iterator[T]): Iterator[T] = {
if (fraction <= 0.0) {
Iterator.empty
- } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
- new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
+ } else if (useGapSamplingIfPossible &&
+ fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
+ new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
} else {
- items.flatMap { item => {
+ items.flatMap { item =>
val count = rng.sample()
if (count == 0) Iterator.empty else Iterator.fill(count)(item)
- }}
+ }
}
}
- override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction)
+ override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction, useGapSamplingIfPossible)
}
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e948ca33471a4..ffe4b4baffb2a 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -51,7 +51,6 @@
import org.apache.spark.api.java.*;
import org.apache.spark.api.java.function.*;
-import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.partial.BoundedDouble;
import org.apache.spark.partial.PartialResult;
@@ -1011,7 +1010,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics());
+ TaskContext context = TaskContext$.MODULE$.empty();
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
index db9e82759090a..934b7e03050b6 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
@@ -32,8 +32,8 @@ public class PackedRecordPointerSuite {
public void heap() {
final TaskMemoryManager memoryManager =
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
- final MemoryBlock page0 = memoryManager.allocatePage(100);
- final MemoryBlock page1 = memoryManager.allocatePage(100);
+ final MemoryBlock page0 = memoryManager.allocatePage(128);
+ final MemoryBlock page1 = memoryManager.allocatePage(128);
final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
page1.getBaseOffset() + 42);
PackedRecordPointer packedPointer = new PackedRecordPointer();
@@ -50,8 +50,8 @@ public void heap() {
public void offHeap() {
final TaskMemoryManager memoryManager =
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
- final MemoryBlock page0 = memoryManager.allocatePage(100);
- final MemoryBlock page1 = memoryManager.allocatePage(100);
+ final MemoryBlock page0 = memoryManager.allocatePage(128);
+ final MemoryBlock page1 = memoryManager.allocatePage(128);
final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
page1.getBaseOffset() + 42);
PackedRecordPointer packedPointer = new PackedRecordPointer();
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
index 8fa72597db24d..40fefe2c9d140 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
@@ -24,7 +24,7 @@
import org.junit.Test;
import org.apache.spark.HashPartitioner;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;
@@ -34,11 +34,7 @@ public class UnsafeShuffleInMemorySorterSuite {
private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
final byte[] strBytes = new byte[strLength];
- PlatformDependent.copyMemory(
- baseObject,
- baseOffset,
- strBytes,
- PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
+ Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
return new String(strBytes);
}
@@ -74,14 +70,10 @@ public void testBasicSorting() throws Exception {
for (String str : dataToSort) {
final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
final byte[] strBytes = str.getBytes("utf-8");
- PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+ Platform.putInt(baseObject, position, strBytes.length);
position += 4;
- PlatformDependent.copyMemory(
- strBytes,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- baseObject,
- position,
- strBytes.length);
+ Platform.copyMemory(
+ strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
position += strBytes.length;
sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
}
@@ -98,7 +90,7 @@ public void testBasicSorting() throws Exception {
Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
partitionId >= prevPartitionId);
final long recordAddress = iter.packedRecordPointer.getRecordPointer();
- final int recordLength = PlatformDependent.UNSAFE.getInt(
+ final int recordLength = Platform.getInt(
memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
final String str = getStringFromDataPage(
memoryManager.getPage(recordAddress),
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 04fc09b323dbb..94650be536b5f 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -115,6 +115,7 @@ public void setUp() throws IOException {
taskMetrics = new TaskMetrics();
when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+ when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
when(blockManager.getDiskWriter(
@@ -190,6 +191,7 @@ public Tuple2 answer(
});
when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+ when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
when(shuffleDep.serializer()).thenReturn(Option.apply(serializer));
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
@@ -473,62 +475,22 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception
@Test
public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
- // Use a custom serializer so that we have exact control over the size of serialized data.
- final Serializer byteArraySerializer = new Serializer() {
- @Override
- public SerializerInstance newInstance() {
- return new SerializerInstance() {
- @Override
- public SerializationStream serializeStream(final OutputStream s) {
- return new SerializationStream() {
- @Override
- public void flush() { }
-
- @Override
- public SerializationStream writeObject(T t, ClassTag ev1) {
- byte[] bytes = (byte[]) t;
- try {
- s.write(bytes);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- return this;
- }
-
- @Override
- public void close() { }
- };
- }
- public ByteBuffer serialize(T t, ClassTag ev1) { return null; }
- public DeserializationStream deserializeStream(InputStream s) { return null; }
- public T deserialize(ByteBuffer b, ClassLoader l, ClassTag ev1) { return null; }
- public T deserialize(ByteBuffer bytes, ClassTag ev1) { return null; }
- };
- }
- };
- when(shuffleDep.serializer()).thenReturn(Option.apply(byteArraySerializer));
final UnsafeShuffleWriter
@@ -461,6 +480,92 @@ for binarized_feature, in binarizedFeatures.collect():
+## PCA
+
+[PCA](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical procedure that uses an orthogonal transformation to convert a set of observations of possibly correlated variables into a set of values of linearly uncorrelated variables called principal components. A [PCA](api/scala/index.html#org.apache.spark.ml.feature.PCA) class trains a model to project vectors to a low-dimensional space using PCA. The example below shows how to project 5-dimensional feature vectors into 3-dimensional principal components.
+
+
+
+See the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.feature.PCA) for API details.
+{% highlight scala %}
+import org.apache.spark.ml.feature.PCA
+import org.apache.spark.mllib.linalg.Vectors
+
+val data = Array(
+ Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
+ Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+ Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+)
+val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")
+val pca = new PCA()
+ .setInputCol("features")
+ .setOutputCol("pcaFeatures")
+ .setK(3)
+ .fit(df)
+val pcaDF = pca.transform(df)
+val result = pcaDF.select("pcaFeatures")
+result.show()
+{% endhighlight %}
+
+
+
+See the [Java API documentation](api/java/org/apache/spark/ml/feature/PCA.html) for API details.
+{% highlight java %}
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.feature.PCA;
+import org.apache.spark.ml.feature.PCAModel;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+JavaSparkContext jsc = ...
+SQLContext jsql = ...
+JavaRDD data = jsc.parallelize(Lists.newArrayList(
+ RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})),
+ RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
+ RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
+));
+StructType schema = new StructType(new StructField[] {
+ new StructField("features", new VectorUDT(), false, Metadata.empty()),
+});
+DataFrame df = jsql.createDataFrame(data, schema);
+PCAModel pca = new PCA()
+ .setInputCol("features")
+ .setOutputCol("pcaFeatures")
+ .setK(3)
+ .fit(df);
+DataFrame result = pca.transform(df).select("pcaFeatures");
+result.show();
+{% endhighlight %}
+
+
## PolynomialExpansion
[Polynomial expansion](http://en.wikipedia.org/wiki/Polynomial_expansion) is the process of expanding your features into a polynomial space, which is formulated by an n-degree combination of original dimensions. A [PolynomialExpansion](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) class provides this functionality. The example below shows how to expand your features into a 3-degree polynomial space.
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 8c46adf256a9a..a03ab4356a413 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -355,6 +355,74 @@ jsc.stop();
{% endhighlight %}
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.param import Param, Params
+from pyspark.sql import Row, SQLContext
+
+sc = SparkContext(appName="SimpleParamsExample")
+sqlContext = SQLContext(sc)
+
+# Prepare training data.
+# We use LabeledPoint.
+# Spark SQL can convert RDDs of LabeledPoints into DataFrames.
+training = sc.parallelize([LabeledPoint(1.0, [0.0, 1.1, 0.1]),
+ LabeledPoint(0.0, [2.0, 1.0, -1.0]),
+ LabeledPoint(0.0, [2.0, 1.3, 1.0]),
+ LabeledPoint(1.0, [0.0, 1.2, -0.5])])
+
+# Create a LogisticRegression instance. This instance is an Estimator.
+lr = LogisticRegression(maxIter=10, regParam=0.01)
+# Print out the parameters, documentation, and any default values.
+print("LogisticRegression parameters:\n" + lr.explainParams() + "\n")
+
+# Learn a LogisticRegression model. This uses the parameters stored in lr.
+model1 = lr.fit(training.toDF())
+
+# Since model1 is a Model (i.e., a transformer produced by an Estimator),
+# we can view the parameters it used during fit().
+# This prints the parameter (name: value) pairs, where names are unique IDs for this
+# LogisticRegression instance.
+print("Model 1 was fit using parameters: ")
+print(model1.extractParamMap())
+
+# We may alternatively specify parameters using a Python dictionary as a paramMap
+paramMap = {lr.maxIter: 20}
+paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter.
+paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params.
+
+# You can combine paramMaps, which are python dictionaries.
+paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name
+paramMapCombined = paramMap.copy()
+paramMapCombined.update(paramMap2)
+
+# Now learn a new model using the paramMapCombined parameters.
+# paramMapCombined overrides all parameters set earlier via lr.set* methods.
+model2 = lr.fit(training.toDF(), paramMapCombined)
+print("Model 2 was fit using parameters: ")
+print(model2.extractParamMap())
+
+# Prepare test data
+test = sc.parallelize([LabeledPoint(1.0, [-1.0, 1.5, 1.3]),
+ LabeledPoint(0.0, [ 3.0, 2.0, -0.1]),
+ LabeledPoint(1.0, [ 0.0, 2.2, -1.5])])
+
+# Make predictions on test data using the Transformer.transform() method.
+# LogisticRegression.transform will only use the 'features' column.
+# Note that model2.transform() outputs a "myProbability" column instead of the usual
+# 'probability' column since we renamed the lr.probabilityCol parameter previously.
+prediction = model2.transform(test.toDF())
+selected = prediction.select("features", "label", "myProbability", "prediction")
+for row in selected.collect():
+ print(row)
+
+sc.stop()
+{% endhighlight %}
+
+
## Example: Pipeline
@@ -561,7 +629,7 @@ test = sc.parallelize([(4L, "spark i j k"),
prediction = model.transform(test)
selected = prediction.select("id", "text", "prediction")
for row in selected.collect():
- print row
+ print(row)
sc.stop()
{% endhighlight %}
diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md
index 3aa040046fca5..f0e8d5495675d 100644
--- a/docs/mllib-data-types.md
+++ b/docs/mllib-data-types.md
@@ -372,12 +372,37 @@ long m = mat.numRows();
long n = mat.numCols();
{% endhighlight %}
+
+
+
+A [`RowMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) can be
+created from an `RDD` of vectors.
+
+{% highlight python %}
+from pyspark.mllib.linalg.distributed import RowMatrix
+
+# Create an RDD of vectors.
+rows = sc.parallelize([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
+
+# Create a RowMatrix from an RDD of vectors.
+mat = RowMatrix(rows)
+
+# Get its size.
+m = mat.numRows() # 4
+n = mat.numCols() # 3
+
+# Get the rows as an RDD of vectors again.
+rowsRDD = mat.rows
+{% endhighlight %}
+
+
### IndexedRowMatrix
An `IndexedRowMatrix` is similar to a `RowMatrix` but with meaningful row indices. It is backed by
-an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local vector.
+an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local
+vector.
@@ -431,7 +456,51 @@ long n = mat.numCols();
// Drop its row indices.
RowMatrix rowMat = mat.toRowMatrix();
{% endhighlight %}
-
+
+
+
+
+An [`IndexedRowMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.IndexedRowMatrix)
+can be created from an `RDD` of `IndexedRow`s, where
+[`IndexedRow`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.IndexedRow) is a
+wrapper over `(long, vector)`. An `IndexedRowMatrix` can be converted to a `RowMatrix` by dropping
+its row indices.
+
+{% highlight python %}
+from pyspark.mllib.linalg.distributed import IndexedRow, IndexedRowMatrix
+
+# Create an RDD of indexed rows.
+# - This can be done explicitly with the IndexedRow class:
+indexedRows = sc.parallelize([IndexedRow(0, [1, 2, 3]),
+ IndexedRow(1, [4, 5, 6]),
+ IndexedRow(2, [7, 8, 9]),
+ IndexedRow(3, [10, 11, 12])])
+# - or by using (long, vector) tuples:
+indexedRows = sc.parallelize([(0, [1, 2, 3]), (1, [4, 5, 6]),
+ (2, [7, 8, 9]), (3, [10, 11, 12])])
+
+# Create an IndexedRowMatrix from an RDD of IndexedRows.
+mat = IndexedRowMatrix(indexedRows)
+
+# Get its size.
+m = mat.numRows() # 4
+n = mat.numCols() # 3
+
+# Get the rows as an RDD of IndexedRows.
+rowsRDD = mat.rows
+
+# Convert to a RowMatrix by dropping the row indices.
+rowMat = mat.toRowMatrix()
+
+# Convert to a CoordinateMatrix.
+coordinateMat = mat.toCoordinateMatrix()
+
+# Convert to a BlockMatrix.
+blockMat = mat.toBlockMatrix()
+{% endhighlight %}
+
+
+A [`CoordinateMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.CoordinateMatrix)
+can be created from an `RDD` of `MatrixEntry` entries, where
+[`MatrixEntry`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.MatrixEntry) is a
+wrapper over `(long, long, float)`. A `CoordinateMatrix` can be converted to a `RowMatrix` by
+calling `toRowMatrix`, or to an `IndexedRowMatrix` with sparse rows by calling `toIndexedRowMatrix`.
+
+{% highlight python %}
+from pyspark.mllib.linalg.distributed import CoordinateMatrix, MatrixEntry
+
+# Create an RDD of coordinate entries.
+# - This can be done explicitly with the MatrixEntry class:
+entries = sc.parallelize([MatrixEntry(0, 0, 1.2), MatrixEntry(1, 0, 2.1), MatrixEntry(6, 1, 3.7)])
+# - or using (long, long, float) tuples:
+entries = sc.parallelize([(0, 0, 1.2), (1, 0, 2.1), (2, 1, 3.7)])
+
+# Create a CoordinateMatrix from an RDD of MatrixEntries.
+mat = CoordinateMatrix(entries)
+
+# Get its size.
+m = mat.numRows() # 3
+n = mat.numCols() # 2
+
+# Get the entries as an RDD of MatrixEntries.
+entriesRDD = mat.entries
+
+# Convert to a RowMatrix.
+rowMat = mat.toRowMatrix()
+
+# Convert to an IndexedRowMatrix.
+indexedRowMat = mat.toIndexedRowMatrix()
+
+# Convert to a BlockMatrix.
+blockMat = mat.toBlockMatrix()
+{% endhighlight %}
+
+
+A [`BlockMatrix`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.BlockMatrix)
+can be created from an `RDD` of sub-matrix blocks, where a sub-matrix block is a
+`((blockRowIndex, blockColIndex), sub-matrix)` tuple.
+
+{% highlight python %}
+from pyspark.mllib.linalg import Matrices
+from pyspark.mllib.linalg.distributed import BlockMatrix
+
+# Create an RDD of sub-matrix blocks.
+blocks = sc.parallelize([((0, 0), Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])),
+ ((1, 0), Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12]))])
+
+# Create a BlockMatrix from an RDD of sub-matrix blocks.
+mat = BlockMatrix(blocks, 3, 2)
+
+# Get its size.
+m = mat.numRows() # 6
+n = mat.numCols() # 2
+
+# Get the blocks as an RDD of sub-matrix blocks.
+blocksRDD = mat.blocks
+
+# Convert to a LocalMatrix.
+localMat = mat.toLocalMatrix()
+
+# Convert to an IndexedRowMatrix.
+indexedRowMat = mat.toIndexedRowMatrix()
+
+# Convert to a CoordinateMatrix.
+coordinateMat = mat.toCoordinateMatrix()
+{% endhighlight %}
+
diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
index 4ca0bb06b26a6..7066d5c97418c 100644
--- a/docs/mllib-evaluation-metrics.md
+++ b/docs/mllib-evaluation-metrics.md
@@ -302,10 +302,10 @@ predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp
metrics = BinaryClassificationMetrics(predictionAndLabels)
# Area under precision-recall curve
-print "Area under PR = %s" % metrics.areaUnderPR
+print("Area under PR = %s" % metrics.areaUnderPR)
# Area under ROC curve
-print "Area under ROC = %s" % metrics.areaUnderROC
+print("Area under ROC = %s" % metrics.areaUnderROC)
{% endhighlight %}
@@ -606,24 +606,24 @@ metrics = MulticlassMetrics(predictionAndLabels)
precision = metrics.precision()
recall = metrics.recall()
f1Score = metrics.fMeasure()
-print "Summary Stats"
-print "Precision = %s" % precision
-print "Recall = %s" % recall
-print "F1 Score = %s" % f1Score
+print("Summary Stats")
+print("Precision = %s" % precision)
+print("Recall = %s" % recall)
+print("F1 Score = %s" % f1Score)
# Statistics by class
labels = data.map(lambda lp: lp.label).distinct().collect()
for label in sorted(labels):
- print "Class %s precision = %s" % (label, metrics.precision(label))
- print "Class %s recall = %s" % (label, metrics.recall(label))
- print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))
+ print("Class %s precision = %s" % (label, metrics.precision(label)))
+ print("Class %s recall = %s" % (label, metrics.recall(label)))
+ print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)))
# Weighted stats
-print "Weighted recall = %s" % metrics.weightedRecall
-print "Weighted precision = %s" % metrics.weightedPrecision
-print "Weighted F(1) Score = %s" % metrics.weightedFMeasure()
-print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)
-print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate
+print("Weighted recall = %s" % metrics.weightedRecall)
+print("Weighted precision = %s" % metrics.weightedPrecision)
+print("Weighted F(1) Score = %s" % metrics.weightedFMeasure())
+print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5))
+print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate)
{% endhighlight %}
@@ -881,28 +881,28 @@ scoreAndLabels = sc.parallelize([
metrics = MultilabelMetrics(scoreAndLabels)
# Summary stats
-print "Recall = %s" % metrics.recall()
-print "Precision = %s" % metrics.precision()
-print "F1 measure = %s" % metrics.f1Measure()
-print "Accuracy = %s" % metrics.accuracy
+print("Recall = %s" % metrics.recall())
+print("Precision = %s" % metrics.precision())
+print("F1 measure = %s" % metrics.f1Measure())
+print("Accuracy = %s" % metrics.accuracy)
# Individual label stats
labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect()
for label in labels:
- print "Class %s precision = %s" % (label, metrics.precision(label))
- print "Class %s recall = %s" % (label, metrics.recall(label))
- print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))
+ print("Class %s precision = %s" % (label, metrics.precision(label)))
+ print("Class %s recall = %s" % (label, metrics.recall(label)))
+ print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label)))
# Micro stats
-print "Micro precision = %s" % metrics.microPrecision
-print "Micro recall = %s" % metrics.microRecall
-print "Micro F1 measure = %s" % metrics.microF1Measure
+print("Micro precision = %s" % metrics.microPrecision)
+print("Micro recall = %s" % metrics.microRecall)
+print("Micro F1 measure = %s" % metrics.microF1Measure)
# Hamming loss
-print "Hamming loss = %s" % metrics.hammingLoss
+print("Hamming loss = %s" % metrics.hammingLoss)
# Subset accuracy
-print "Subset accuracy = %s" % metrics.subsetAccuracy
+print("Subset accuracy = %s" % metrics.subsetAccuracy)
{% endhighlight %}
@@ -1283,10 +1283,10 @@ scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1])
metrics = RegressionMetrics(scoreAndLabels)
# Root mean squared error
-print "RMSE = %s" % metrics.rootMeanSquaredError
+print("RMSE = %s" % metrics.rootMeanSquaredError)
# R-squared
-print "R-squared = %s" % metrics.r2
+print("R-squared = %s" % metrics.r2)
{% endhighlight %}
@@ -1479,17 +1479,17 @@ valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.l
metrics = RegressionMetrics(valuesAndPreds)
# Squared Error
-print "MSE = %s" % metrics.meanSquaredError
-print "RMSE = %s" % metrics.rootMeanSquaredError
+print("MSE = %s" % metrics.meanSquaredError)
+print("RMSE = %s" % metrics.rootMeanSquaredError)
# R-squared
-print "R-squared = %s" % metrics.r2
+print("R-squared = %s" % metrics.r2)
# Mean absolute error
-print "MAE = %s" % metrics.meanAbsoluteError
+print("MAE = %s" % metrics.meanAbsoluteError)
# Explained variance
-print "Explained variance = %s" % metrics.explainedVariance
+print("Explained variance = %s" % metrics.explainedVariance)
{% endhighlight %}
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index a69e41e2a1936..de86aba2ae627 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -221,7 +221,7 @@ model = word2vec.fit(inp)
synonyms = model.findSynonyms('china', 40)
for word, cosine_distance in synonyms:
- print "{}: {}".format(word, cosine_distance)
+ print("{}: {}".format(word, cosine_distance))
{% endhighlight %}
diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md
index de5d6485f9b5f..be04d0b4b53a8 100644
--- a/docs/mllib-statistics.md
+++ b/docs/mllib-statistics.md
@@ -95,9 +95,9 @@ mat = ... # an RDD of Vectors
# Compute column summary statistics.
summary = Statistics.colStats(mat)
-print summary.mean()
-print summary.variance()
-print summary.numNonzeros()
+print(summary.mean())
+print(summary.variance())
+print(summary.numNonzeros())
{% endhighlight %}
@@ -183,12 +183,12 @@ seriesY = ... # must have the same number of partitions and cardinality as serie
# Compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a
# method is not specified, Pearson's method will be used by default.
-print Statistics.corr(seriesX, seriesY, method="pearson")
+print(Statistics.corr(seriesX, seriesY, method="pearson"))
data = ... # an RDD of Vectors
# calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method.
# If a method is not specified, Pearson's method will be used by default.
-print Statistics.corr(data, method="pearson")
+print(Statistics.corr(data, method="pearson"))
{% endhighlight %}
@@ -398,14 +398,14 @@ vec = Vectors.dense(...) # a vector composed of the frequencies of events
# compute the goodness of fit. If a second vector to test against is not supplied as a parameter,
# the test runs against a uniform distribution.
goodnessOfFitTestResult = Statistics.chiSqTest(vec)
-print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom,
- # test statistic, the method used, and the null hypothesis.
+print(goodnessOfFitTestResult) # summary of the test including the p-value, degrees of freedom,
+ # test statistic, the method used, and the null hypothesis.
mat = Matrices.dense(...) # a contingency matrix
# conduct Pearson's independence test on the input contingency matrix
independenceTestResult = Statistics.chiSqTest(mat)
-print independenceTestResult # summary of the test including the p-value, degrees of freedom...
+print(independenceTestResult) # summary of the test including the p-value, degrees of freedom...
obs = sc.parallelize(...) # LabeledPoint(feature, label) .
@@ -415,8 +415,8 @@ obs = sc.parallelize(...) # LabeledPoint(feature, label) .
featureTestResults = Statistics.chiSqTest(obs)
for i, result in enumerate(featureTestResults):
- print "Column $d:" % (i + 1)
- print result
+ print("Column %d:" % (i + 1))
+ print(result)
{% endhighlight %}
diff --git a/docs/monitoring.md b/docs/monitoring.md
index bcf885fe4e681..cedceb2958023 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -48,7 +48,7 @@ follows:
Environment Variable
Meaning
SPARK_DAEMON_MEMORY
-
Memory to allocate to the history server (default: 512m).
+
Memory to allocate to the history server (default: 1g).
SPARK_DAEMON_JAVA_OPTS
diff --git a/docs/quick-start.md b/docs/quick-start.md
index bb39e4111f244..ce2cc9d2169cd 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -406,7 +406,7 @@ logData = sc.textFile(logFile).cache()
numAs = logData.filter(lambda s: 'a' in s).count()
numBs = logData.filter(lambda s: 'b' in s).count()
-print "Lines with a: %i, lines with b: %i" % (numAs, numBs)
+print("Lines with a: %i, lines with b: %i" % (numAs, numBs))
{% endhighlight %}
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index debdd2adf22d6..cfd219ab02e26 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -216,6 +216,20 @@ node. Please refer to [Hadoop on Mesos](https://github.com/mesos/hadoop).
In either case, HDFS runs separately from Hadoop MapReduce, without being scheduled through Mesos.
+# Dynamic Resource Allocation with Mesos
+
+Mesos supports dynamic allocation only with coarse grain mode, which can resize the number of executors based on statistics
+of the application. While dynamic allocation supports both scaling up and scaling down the number of executors, the coarse grain scheduler only supports scaling down
+since it is already designed to run one executor per slave with the configured amount of resources. However, after scaling down the number of executors the coarse grain scheduler
+can scale back up to the same amount of executors when Spark signals more executors are needed.
+
+Users who would like to utilize this feature should launch the Mesos Shuffle Service that
+provides shuffle data cleanup functionality on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's
+termination. To launch/stop the Mesos Shuffle Service please use the provided sbin/start-mesos-shuffle-service.sh and sbin/stop-mesos-shuffle-service.sh
+scripts accordingly.
+
+The Shuffle Service is expected to be running on each slave node that will run Spark executors. One way to easily achieve this with Mesos
+is to launch the Shuffle Service with Marathon with a unique host constraint.
# Configuration
@@ -306,6 +320,14 @@ See the [configuration page](configuration.html) for information on Spark config
the final overhead will be this value.
+
+
spark.mesos.uris
+
(none)
+
+ A list of URIs to be downloaded to the sandbox when driver or executor is launched by Mesos.
+ This applies to both coarse-grain and fine-grain mode.
+
+
spark.mesos.principal
Framework principal to authenticate to Mesos
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index cac08a91b97d9..ec32c419b7c51 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -199,7 +199,7 @@ If you need a reference to the proper location to put log files in the YARN so t
spark.executor.instances
2
- The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled.
+ The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used.
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 4f71fbc086cd0..2fe9ec3542b28 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -152,7 +152,7 @@ You can optionally configure the cluster further by setting environment variable
SPARK_DAEMON_MEMORY
-
Memory to allocate to the Spark master and worker daemons themselves (default: 512m).
+
Memory to allocate to the Spark master and worker daemons themselves (default: 1g).
SPARK_DAEMON_JAVA_OPTS
diff --git a/docs/sparkr.md b/docs/sparkr.md
index 4385a4eeacd5c..7139d16b4a068 100644
--- a/docs/sparkr.md
+++ b/docs/sparkr.md
@@ -11,7 +11,8 @@ title: SparkR (R on Spark)
SparkR is an R package that provides a light-weight frontend to use Apache Spark from R.
In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that
supports operations like selection, filtering, aggregation etc. (similar to R data frames,
-[dplyr](https://github.com/hadley/dplyr)) but on large datasets.
+[dplyr](https://github.com/hadley/dplyr)) but on large datasets. SparkR also supports distributed
+machine learning using MLlib.
# SparkR DataFrames
@@ -230,3 +231,37 @@ head(teenagers)
{% endhighlight %}
+
+# Machine Learning
+
+SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', '+', and '-'. The example below demonstrates how to build a gaussian GLM model using SparkR.
+
+
+{% highlight r %}
+# Create the DataFrame
+df <- createDataFrame(sqlContext, iris)
+
+# Fit a linear model over the dataset.
+model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian")
+
+# Model coefficients are returned in a similar format to R's native glm().
+summary(model)
+##$coefficients
+## Estimate
+##(Intercept) 2.2513930
+##Sepal_Width 0.8035609
+##Species_versicolor 1.4587432
+##Species_virginica 1.9468169
+
+# Make predictions based on the model.
+predictions <- predict(model, newData = df)
+head(select(predictions, "Sepal_Length", "prediction"))
+## Sepal_Length prediction
+##1 5.1 5.063856
+##2 4.9 4.662076
+##3 4.7 4.822788
+##4 4.6 4.742432
+##5 5.0 5.144212
+##6 5.4 5.385281
+{% endhighlight %}
+
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 95945eb7fc8a0..6c317175d3278 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -570,7 +570,7 @@ teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 1
# The results of SQL queries are RDDs and support all the normal RDD operations.
teenNames = teenagers.map(lambda p: "Name: " + p.name)
for teenName in teenNames.collect():
- print teenName
+ print(teenName)
{% endhighlight %}
@@ -752,7 +752,7 @@ results = sqlContext.sql("SELECT name FROM people")
# The results of SQL queries are RDDs and support all the normal RDD operations.
names = results.map(lambda p: "Name: " + p.name)
for name in names.collect():
- print name
+ print(name)
{% endhighlight %}
@@ -1006,7 +1006,7 @@ parquetFile.registerTempTable("parquetFile");
teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19")
teenNames = teenagers.map(lambda p: "Name: " + p.name)
for teenName in teenNames.collect():
- print teenName
+ print(teenName)
{% endhighlight %}
@@ -1884,12 +1884,11 @@ that these options will be deprecated in future release as more optimizations ar
-
spark.sql.codegen
-
false
+
spark.sql.tungsten.enabled
+
true
- When true, code will be dynamically generated at runtime for expression evaluation in a specific
- query. For some queries with complicated expression this option can lead to significant speed-ups.
- However, for simple queries this can actually slow down query execution.
+ When true, use the optimized Tungsten physical execution backend which explicitly manages memory
+ and dynamically generates bytecode for expression evaluation.
@@ -1901,7 +1900,7 @@ that these options will be deprecated in future release as more optimizations ar
spark.sql.planner.externalSort
-
false
+
true
When true, performs sorts spilling to disk as needed otherwise sort each partition in memory.
diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md
index 775d508d4879b..7571e22575efd 100644
--- a/docs/streaming-kafka-integration.md
+++ b/docs/streaming-kafka-integration.md
@@ -152,7 +152,7 @@ Next, we discuss how to use this approach in your streaming application.
// Hold a reference to the current offset ranges, so it can be used downstream
- final AtomicReference offsetRanges = new AtomicReference();
+ final AtomicReference<OffsetRange[]> offsetRanges = new AtomicReference<>();
directKafkaStream.transformToPair(
new Function, JavaPairRDD>() {
diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md
index aa9749afbc867..a7bcaec6fcd84 100644
--- a/docs/streaming-kinesis-integration.md
+++ b/docs/streaming-kinesis-integration.md
@@ -51,6 +51,17 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m
See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html)
and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example.
+
+
+ from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream
+
+ kinesisStream = KinesisUtils.createStream(
+ streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL],
+ [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2)
+
+ See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kinesis.KinesisUtils)
+ and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the next subsection for instructions to run the example.
+
@@ -135,6 +146,14 @@ To run the example,
bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
+
+
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 2f3013b533eb0..c59d936b43c88 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea
{:.no_toc}
Python API As of Spark {{site.SPARK_VERSION_SHORT}},
-out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future.
+out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources in the Python API in future.
This category of sources require interfacing with external non-Spark libraries, some of them with
complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts
@@ -1141,7 +1141,7 @@ val joinedStream = stream1.join(stream2)
{% highlight java %}
JavaPairDStream stream1 = ...
JavaPairDStream stream2 = ...
-JavaPairDStream joinedStream = stream1.join(stream2);
+JavaPairDStream<String, Tuple2<String, String>> joinedStream = stream1.join(stream2);
{% endhighlight %}
@@ -1525,7 +1525,7 @@ def getSqlContextInstance(sparkContext):
words = ... # DStream of strings
def process(time, rdd):
- print "========= %s =========" % str(time)
+ print("========= %s =========" % str(time))
try:
# Get the singleton instance of SQLContext
sqlContext = getSqlContextInstance(rdd.context)
diff --git a/docs/tuning.md b/docs/tuning.md
index 572c7270e4999..6936912a6be54 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -240,7 +240,7 @@ worth optimizing.
## Data Locality
Data locality can have a major impact on the performance of Spark jobs. If data and the code that
-operates on it are together than computation tends to be fast. But if code and data are separated,
+operates on it are together then computation tends to be fast. But if code and data are separated,
one must move to the other. Typically it is faster to ship serialized code from place to place than
a chunk of data because code size is much smaller than data. Spark builds its scheduling around
this general principle of data locality.
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index ccf922d9371fb..11fd7ee0ec8df 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -90,7 +90,7 @@
DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark"
# Default location to get the spark-ec2 scripts (and ami-list) from
-DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/mesos/spark-ec2"
+DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2"
DEFAULT_SPARK_EC2_BRANCH = "branch-1.4"
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index 9df26ffca5775..3f1fe900b0008 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -230,6 +230,7 @@ public Vector predictRaw(Vector features) {
*/
@Override
public MyJavaLogisticRegressionModel copy(ParamMap extra) {
- return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra);
+ return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra)
+ .setParent(parent());
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
new file mode 100644
index 0000000000000..be2bf0c7b465c
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.regex.Pattern;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.ml.clustering.KMeansModel;
+import org.apache.spark.ml.clustering.KMeans;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+
+/**
+ * An example demonstrating a k-means clustering.
+ * Run with
+ *
+ */
+public class JavaKMeansExample {
+
+ private static class ParsePoint implements Function<String, Row> {
+ private static final Pattern separator = Pattern.compile(" ");
+
+ @Override
+ public Row call(String line) {
+ String[] tok = separator.split(line);
+ double[] point = new double[tok.length];
+ for (int i = 0; i < tok.length; ++i) {
+ point[i] = Double.parseDouble(tok[i]);
+ }
+ Vector[] points = {Vectors.dense(point)};
+ return new GenericRow(points);
+ }
+ }
+
+ public static void main(String[] args) {
+ if (args.length != 2) {
+ System.err.println("Usage: ml.JavaKMeansExample <file> <k>");
+ System.exit(1);
+ }
+ String inputFile = args[0];
+ int k = Integer.parseInt(args[1]);
+
+ // Parses the arguments
+ SparkConf conf = new SparkConf().setAppName("JavaKMeansExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext sqlContext = new SQLContext(jsc);
+
+ // Loads data
+ JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParsePoint());
+ StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
+ StructType schema = new StructType(fields);
+ DataFrame dataset = sqlContext.createDataFrame(points, schema);
+
+ // Trains a k-means model
+ KMeans kmeans = new KMeans()
+ .setK(k);
+ KMeansModel model = kmeans.fit(dataset);
+
+ // Shows the result
+ Vector[] centers = model.clusterCenters();
+ System.out.println("Cluster Centers: ");
+ for (Vector center: centers) {
+ System.out.println(center);
+ }
+
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
index 75063dbf800d8..e7f2f6f615070 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
@@ -178,6 +178,7 @@ private static Params parse(String[] args) {
return params;
}
+ @SuppressWarnings("static")
private static Options generateCommandlineOptions() {
Option input = OptionBuilder.withArgName("input")
.hasArg()
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index dac649d1d5ae6..94beeced3d479 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -77,7 +77,8 @@ public static void main(String[] args) {
ParamMap paramMap = new ParamMap();
paramMap.put(lr.maxIter().w(20)); // Specify 1 Param.
paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
- paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.
+ double thresholds[] = {0.45, 0.55};
+ paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params.
// One can also combine ParamMaps.
ParamMap paramMap2 = new ParamMap();
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index dbf2ef02d7b76..02f58f48b07ab 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -85,7 +85,7 @@ public Optional call(List values, Optional state) {
@SuppressWarnings("unchecked")
List> tuples = Arrays.asList(new Tuple2("hello", 1),
new Tuple2("world", 1));
- JavaPairRDD initialRDD = ssc.sc().parallelizePairs(tuples);
+ JavaPairRDD initialRDD = ssc.sparkContext().parallelizePairs(tuples);
JavaReceiverInputDStream lines = ssc.socketTextStream(
args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2);
@@ -107,7 +107,7 @@ public Tuple2 call(String s) {
// This will give a Dstream made of state (which is the cumulative count of the words)
JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction,
- new HashPartitioner(ssc.sc().defaultParallelism()), initialRDD);
+ new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD);
stateDstream.print();
ssc.start();
diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py
new file mode 100644
index 0000000000000..150dadd42f33e
--- /dev/null
+++ b/examples/src/main/python/ml/kmeans_example.py
@@ -0,0 +1,71 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+import sys
+import re
+
+import numpy as np
+from pyspark import SparkContext
+from pyspark.ml.clustering import KMeans, KMeansModel
+from pyspark.mllib.linalg import VectorUDT, _convert_to_vector
+from pyspark.sql import SQLContext
+from pyspark.sql.types import Row, StructField, StructType
+
+"""
+A simple example demonstrating a k-means clustering.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/kmeans_example.py
+
+This example requires NumPy (http://www.numpy.org/).
+"""
+
+
+def parseVector(line):
+ array = np.array([float(x) for x in line.split(' ')])
+ return _convert_to_vector(array)
+
+
+if __name__ == "__main__":
+
+ FEATURES_COL = "features"
+
+ if len(sys.argv) != 3:
+ print("Usage: kmeans_example.py <file> <k>", file=sys.stderr)
+ exit(-1)
+ path = sys.argv[1]
+ k = sys.argv[2]
+
+ sc = SparkContext(appName="PythonKMeansExample")
+ sqlContext = SQLContext(sc)
+
+ lines = sc.textFile(path)
+ data = lines.map(parseVector)
+ row_rdd = data.map(lambda x: Row(x))
+ schema = StructType([StructField(FEATURES_COL, VectorUDT(), False)])
+ df = sqlContext.createDataFrame(row_rdd, schema)
+
+ kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol(FEATURES_COL)
+ model = kmeans.fit(df)
+ centers = model.clusterCenters()
+
+ print("Cluster Centers: ")
+ for center in centers:
+ print(center)
+
+ sc.stop()
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
index a9f29dab2d602..2d6d115d54d02 100644
--- a/examples/src/main/python/ml/simple_params_example.py
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -70,7 +70,7 @@
# We may alternatively specify parameters using a parameter map.
# paramMap overrides all lr parameters set earlier.
- paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"}
+ paramMap = {lr.maxIter: 20, lr.thresholds: [0.45, 0.55], lr.probabilityCol: "myProbability"}
# Now learn a new model using the new parameters.
model2 = lr.fit(training, paramMap)
diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py
new file mode 100644
index 0000000000000..617ce5ea6775e
--- /dev/null
+++ b/examples/src/main/python/streaming/mqtt_wordcount.py
@@ -0,0 +1,58 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ A sample wordcount with MqttStream stream
+ Usage: mqtt_wordcount.py <broker url> <topic>
+
+ To run this in your local machine, you need to setup a MQTT broker and publisher first,
+ Mosquitto is one of the open source MQTT Brokers, see
+ http://mosquitto.org/
+ Eclipse paho project provides number of clients and utilities for working with MQTT, see
+ http://www.eclipse.org/paho/#getting-started
+
+ and then run the example
+ `$ bin/spark-submit --jars external/mqtt-assembly/target/scala-*/\
+ spark-streaming-mqtt-assembly-*.jar examples/src/main/python/streaming/mqtt_wordcount.py \
+ tcp://localhost:1883 foo`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+from pyspark.streaming.mqtt import MQTTUtils
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, "Usage: mqtt_wordcount.py <broker url> <topic>"
+ exit(-1)
+
+ sc = SparkContext(appName="PythonStreamingMQTTWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ brokerUrl = sys.argv[1]
+ topic = sys.argv[2]
+
+ lines = MQTTUtils.createStream(ssc, brokerUrl, topic)
+ counts = lines.flatMap(lambda line: line.split(" ")) \
+ .map(lambda word: (word, 1)) \
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 78f31b4ffe56a..340c3559b15ef 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -179,7 +179,7 @@ private class MyLogisticRegressionModel(
* This is used for the default implementation of [[transform()]].
*/
override def copy(extra: ParamMap): MyLogisticRegressionModel = {
- copyValues(new MyLogisticRegressionModel(uid, weights), extra)
+ copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent)
}
}
// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
new file mode 100644
index 0000000000000..5ce38462d1181
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkContext, SparkConf}
+import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
+import org.apache.spark.ml.clustering.KMeans
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.types.{StructField, StructType}
+
+
+/**
+ * An example demonstrating a k-means clustering.
+ * Run with
+ * {{{
+ * bin/run-example ml.KMeansExample <file> <k>
+ * }}}
+ */
+object KMeansExample {
+
+ final val FEATURES_COL = "features"
+
+ def main(args: Array[String]): Unit = {
+ if (args.length != 2) {
+ // scalastyle:off println
+ System.err.println("Usage: ml.KMeansExample <file> <k>")
+ // scalastyle:on println
+ System.exit(1)
+ }
+ val input = args(0)
+ val k = args(1).toInt
+
+ // Creates a Spark context and a SQL context
+ val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+
+ // Loads data
+ val rowRDD = sc.textFile(input).filter(_.nonEmpty)
+ .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_))
+ val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false)))
+ val dataset = sqlContext.createDataFrame(rowRDD, schema)
+
+ // Trains a k-means model
+ val kmeans = new KMeans()
+ .setK(k)
+ .setFeaturesCol(FEATURES_COL)
+ val model = kmeans.fit(dataset)
+
+ // Shows the result
+ // scalastyle:off println
+ println("Final Centers: ")
+ model.clusterCenters.foreach(println)
+ // scalastyle:on println
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 58d7b67674ff7..f4d1fe57856a1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -70,7 +70,7 @@ object SimpleParamsExample {
// which supports several methods for specifying parameters.
val paramMap = ParamMap(lr.maxIter -> 20)
paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter.
- paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
+ paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.45, 0.55)) // Specify multiple Params.
// One can also combine ParamMaps.
val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index 48a1933d92f85..8a177077775c6 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -29,7 +29,8 @@ import org.apache.spark.{Logging, SparkException}
import org.apache.spark.streaming.{StreamingContext, Time}
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
-import org.apache.spark.streaming.scheduler.StreamInputInfo
+import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo}
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
/**
* A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where
@@ -61,7 +62,7 @@ class DirectKafkaInputDStream[
val kafkaParams: Map[String, String],
val fromOffsets: Map[TopicAndPartition, Long],
messageHandler: MessageAndMetadata[K, V] => R
-) extends InputDStream[R](ssc_) with Logging {
+ ) extends InputDStream[R](ssc_) with Logging {
val maxRetries = context.sparkContext.getConf.getInt(
"spark.streaming.kafka.maxRetries", 1)
@@ -71,14 +72,35 @@ class DirectKafkaInputDStream[
protected[streaming] override val checkpointData =
new DirectKafkaInputDStreamCheckpointData
+
+ /**
+ * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker.
+ */
+ override protected[streaming] val rateController: Option[RateController] = {
+ if (RateController.isBackPressureEnabled(ssc.conf)) {
+ Some(new DirectKafkaRateController(id,
+ RateEstimator.create(ssc.conf, ssc_.graph.batchDuration)))
+ } else {
+ None
+ }
+ }
+
protected val kc = new KafkaCluster(kafkaParams)
- protected val maxMessagesPerPartition: Option[Long] = {
- val ratePerSec = context.sparkContext.getConf.getInt(
+ private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt(
"spark.streaming.kafka.maxRatePerPartition", 0)
- if (ratePerSec > 0) {
+ protected def maxMessagesPerPartition: Option[Long] = {
+ val estimatedRateLimit = rateController.map(_.getLatestRate().toInt)
+ val numPartitions = currentOffsets.keys.size
+
+ val effectiveRateLimitPerPartition = estimatedRateLimit
+ .filter(_ > 0)
+ .map(limit => Math.min(maxRateLimitPerPartition, (limit / numPartitions)))
+ .getOrElse(maxRateLimitPerPartition)
+
+ if (effectiveRateLimitPerPartition > 0) {
val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
- Some((secsPerBatch * ratePerSec).toLong)
+ Some((secsPerBatch * effectiveRateLimitPerPartition).toLong)
} else {
None
}
@@ -170,11 +192,18 @@ class DirectKafkaInputDStream[
val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics))
batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) =>
- logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}")
- generatedRDDs += t -> new KafkaRDD[K, V, U, T, R](
- context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler)
+ logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}")
+ generatedRDDs += t -> new KafkaRDD[K, V, U, T, R](
+ context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler)
}
}
}
+ /**
+ * A RateController to retrieve the rate from RateEstimator.
+ */
+ private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator)
+ extends RateController(id, estimator) {
+ override def publish(rate: Long): Unit = ()
+ }
}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
index 1a9d78c0d4f59..ea5f842c6cafe 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
@@ -197,7 +197,11 @@ class KafkaRDD[
.dropWhile(_.offset < requestOffset)
}
- override def close(): Unit = consumer.close()
+ override def close(): Unit = {
+ if (consumer != null) {
+ consumer.close()
+ }
+ }
override def getNext(): R = {
if (iter == null || !iter.hasNext) {
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
index f326e7f1f6f8d..2f8981d4898bd 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
@@ -42,16 +42,16 @@ trait HasOffsetRanges {
* :: Experimental ::
* Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class
* can be created with `OffsetRange.create()`.
+ * @param topic Kafka topic name
+ * @param partition Kafka partition id
+ * @param fromOffset Inclusive starting offset
+ * @param untilOffset Exclusive ending offset
*/
@Experimental
final class OffsetRange private(
- /** Kafka topic name */
val topic: String,
- /** Kafka partition id */
val partition: Int,
- /** inclusive starting offset */
val fromOffset: Long,
- /** exclusive ending offset */
val untilOffset: Long) extends Serializable {
import OffsetRange.OffsetRangeTuple
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
index 75f0dfc22b9dc..764d170934aa6 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
@@ -96,7 +96,7 @@ class ReliableKafkaReceiver[
blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]()
// Initialize the block generator for storing Kafka message.
- blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf)
+ blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler)
if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") {
logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " +
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
index 02cd24a35906f..9db07d0507fea 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
@@ -70,7 +70,7 @@ public void testKafkaStream() throws InterruptedException {
final String topic1 = "topic1";
final String topic2 = "topic2";
// hold a reference to the current offset ranges, so it can be used downstream
- final AtomicReference offsetRanges = new AtomicReference();
+ final AtomicReference offsetRanges = new AtomicReference<>();
String[] topic1data = createTopicAndSendData(topic1);
String[] topic2data = createTopicAndSendData(topic2);
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index 5b3c79444aa68..02225d5aa7cc5 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -20,6 +20,9 @@ package org.apache.spark.streaming.kafka
import java.io.File
import java.util.concurrent.atomic.AtomicLong
+import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
@@ -350,6 +353,77 @@ class DirectKafkaStreamSuite
ssc.stop()
}
+ test("using rate controller") {
+ val topic = "backpressure"
+ val topicPartition = TopicAndPartition(topic, 0)
+ kafkaTestUtils.createTopic(topic)
+ val kafkaParams = Map(
+ "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
+ "auto.offset.reset" -> "smallest"
+ )
+
+ val batchIntervalMilliseconds = 100
+ val estimator = new ConstantEstimator(100)
+ val messageKeys = (1 to 200).map(_.toString)
+ val messages = messageKeys.map((_, 1)).toMap
+
+ val sparkConf = new SparkConf()
+ // Safe, even with streaming, because we're using the direct API.
+ // Using 1 core is useful to make the test more predictable.
+ .setMaster("local[1]")
+ .setAppName(this.getClass.getSimpleName)
+ .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+ // Setup the streaming context
+ ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+
+ val kafkaStream = withClue("Error creating direct stream") {
+ val kc = new KafkaCluster(kafkaParams)
+ val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
+ val m = kc.getEarliestLeaderOffsets(Set(topicPartition))
+ .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset))
+
+ new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
+ ssc, kafkaParams, m, messageHandler) {
+ override protected[streaming] val rateController =
+ Some(new DirectKafkaRateController(id, estimator))
+ }
+ }
+
+ val collectedData =
+ new mutable.ArrayBuffer[Array[String]]() with mutable.SynchronizedBuffer[Array[String]]
+
+ // Used for assertion failure messages.
+ def dataToString: String =
+ collectedData.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}")
+
+ // This is to collect the raw data received from Kafka
+ kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) =>
+ val data = rdd.map { _._2 }.collect()
+ collectedData += data
+ }
+
+ ssc.start()
+
+ // Try different rate limits.
+ // Send data to Kafka and wait for arrays of data to appear matching the rate.
+ Seq(100, 50, 20).foreach { rate =>
+ collectedData.clear() // Empty this buffer on each pass.
+ estimator.updateRate(rate) // Set a new rate.
+ // Expect blocks of data equal to "rate", scaled by the interval length in secs.
+ val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001)
+ kafkaTestUtils.sendMessages(topic, messages)
+ eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) {
+ // Assert that rate estimator values are used to determine maxMessagesPerPartition.
+ // Funky "-" in message makes the complete assertion message read better.
+ assert(collectedData.exists(_.size == expectedSize),
+ s" - No arrays of size $expectedSize for rate $rate found in $dataToString")
+ }
+ }
+
+ ssc.stop()
+ }
+
/** Get the generated offset ranges from the DirectKafkaStream */
private def getOffsetRanges[K, V](
kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = {
@@ -381,3 +455,18 @@ object DirectKafkaStreamSuite {
}
}
}
+
+private[streaming] class ConstantEstimator(@volatile private var rate: Long)
+ extends RateEstimator {
+
+ def updateRate(newRate: Long): Unit = {
+ rate = newRate
+ }
+
+ def compute(
+ time: Long,
+ elements: Long,
+ processingDelay: Long,
+ schedulingDelay: Long): Option[Double] = Some(rate)
+}
+
diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml
new file mode 100644
index 0000000000000..9c94473053d96
--- /dev/null
+++ b/external/mqtt-assembly/pom.xml
@@ -0,0 +1,102 @@
+
+
+
+
+ 4.0.0
+
+ org.apache.spark
+ spark-parent_2.10
+ 1.5.0-SNAPSHOT
+ ../../pom.xml
+
+
+ org.apache.spark
+ spark-streaming-mqtt-assembly_2.10
+ jar
+ Spark Project External MQTT Assembly
+ http://spark.apache.org/
+
+
+ streaming-mqtt-assembly
+
+
+
+
+ org.apache.spark
+ spark-streaming-mqtt_${scala.binary.version}
+ ${project.version}
+
+
+ org.apache.spark
+ spark-streaming_${scala.binary.version}
+ ${project.version}
+ provided
+
+
+
+
+ target/scala-${scala.binary.version}/classes
+ target/scala-${scala.binary.version}/test-classes
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+
+ false
+ ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-mqtt-assembly-${project.version}.jar
+
+
+ *:*
+
+
+
+
+ *:*
+
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
+
+
+
+
+
+
+ package
+
+ shade
+
+
+
+
+
+ reference.conf
+
+
+ log4j.properties
+
+
+
+
+
+
+
+
+
+
+
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 0e41e5781784b..69b309876a0db 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -78,5 +78,33 @@
target/scala-${scala.binary.version}/classestarget/scala-${scala.binary.version}/test-classes
+
+
+
+
+ org.apache.maven.plugins
+ maven-assembly-plugin
+
+
+ test-jar-with-dependencies
+ package
+
+ single
+
+
+
+ spark-streaming-mqtt-test-${project.version}
+ ${project.build.directory}/scala-${scala.binary.version}/
+ false
+
+ false
+
+ src/main/assembly/assembly.xml
+
+
+
+
+
+
diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml
new file mode 100644
index 0000000000000..ecab5b360eb3e
--- /dev/null
+++ b/external/mqtt/src/main/assembly/assembly.xml
@@ -0,0 +1,44 @@
+
+
+ test-jar-with-dependencies
+
+ jar
+
+ false
+
+
+
+ ${project.build.directory}/scala-${scala.binary.version}/test-classes
+ /
+
+
+
+
+
+ true
+ test
+ true
+
+ org.apache.hadoop:*:jar
+ org.apache.zookeeper:*:jar
+ org.apache.avro:*:jar
+
+
+
+
+
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
index 1142d0f56ba34..38a1114863d15 100644
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
+++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
@@ -74,3 +74,19 @@ object MQTTUtils {
createStream(jssc.ssc, brokerUrl, topic, storageLevel)
}
}
+
+/**
 + * This is a helper class that wraps the methods in MQTTUtils into a more Python-friendly class and
+ * function so that it can be easily instantiated and called from Python's MQTTUtils.
+ */
+private class MQTTUtilsPythonHelper {
+
+ def createStream(
+ jssc: JavaStreamingContext,
+ brokerUrl: String,
+ topic: String,
+ storageLevel: StorageLevel
+ ): JavaDStream[String] = {
+ MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel)
+ }
+}
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index c4bf5aa7869bb..a6a9249db8ed7 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -17,46 +17,30 @@
package org.apache.spark.streaming.mqtt
-import java.net.{URI, ServerSocket}
-import java.util.concurrent.CountDownLatch
-import java.util.concurrent.TimeUnit
-
import scala.concurrent.duration._
import scala.language.postfixOps
-import org.apache.activemq.broker.{TransportConnector, BrokerService}
-import org.apache.commons.lang3.RandomUtils
-import org.eclipse.paho.client.mqttv3._
-import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.apache.spark.streaming.scheduler.StreamingListener
-import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.util.Utils
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Milliseconds, StreamingContext}
class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
private val batchDuration = Milliseconds(500)
private val master = "local[2]"
private val framework = this.getClass.getSimpleName
- private val freePort = findFreePort()
- private val brokerUri = "//localhost:" + freePort
private val topic = "def"
- private val persistenceDir = Utils.createTempDir()
private var ssc: StreamingContext = _
- private var broker: BrokerService = _
- private var connector: TransportConnector = _
+ private var mqttTestUtils: MQTTTestUtils = _
before {
ssc = new StreamingContext(master, framework, batchDuration)
- setupMQTT()
+ mqttTestUtils = new MQTTTestUtils
+ mqttTestUtils.setup()
}
after {
@@ -64,14 +48,17 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
ssc.stop()
ssc = null
}
- Utils.deleteRecursively(persistenceDir)
- tearDownMQTT()
+ if (mqttTestUtils != null) {
+ mqttTestUtils.teardown()
+ mqttTestUtils = null
+ }
}
test("mqtt input stream") {
val sendMessage = "MQTT demo for spark streaming"
- val receiveStream =
- MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY)
+ val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic,
+ StorageLevel.MEMORY_ONLY)
+
@volatile var receiveMessage: List[String] = List()
receiveStream.foreachRDD { rdd =>
if (rdd.collect.length > 0) {
@@ -79,89 +66,14 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
receiveMessage
}
}
- ssc.start()
- // wait for the receiver to start before publishing data, or we risk failing
- // the test nondeterministically. See SPARK-4631
- waitForReceiverToStart()
+ ssc.start()
- publishData(sendMessage)
+ // Retry it because we don't know when the receiver will start.
eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
+ mqttTestUtils.publishData(topic, sendMessage)
assert(sendMessage.equals(receiveMessage(0)))
}
ssc.stop()
}
-
- private def setupMQTT() {
- broker = new BrokerService()
- broker.setDataDirectoryFile(Utils.createTempDir())
- connector = new TransportConnector()
- connector.setName("mqtt")
- connector.setUri(new URI("mqtt:" + brokerUri))
- broker.addConnector(connector)
- broker.start()
- }
-
- private def tearDownMQTT() {
- if (broker != null) {
- broker.stop()
- broker = null
- }
- if (connector != null) {
- connector.stop()
- connector = null
- }
- }
-
- private def findFreePort(): Int = {
- val candidatePort = RandomUtils.nextInt(1024, 65536)
- Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
- val socket = new ServerSocket(trialPort)
- socket.close()
- (null, trialPort)
- }, new SparkConf())._2
- }
-
- def publishData(data: String): Unit = {
- var client: MqttClient = null
- try {
- val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
- client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence)
- client.connect()
- if (client.isConnected) {
- val msgTopic = client.getTopic(topic)
- val message = new MqttMessage(data.getBytes("utf-8"))
- message.setQos(1)
- message.setRetained(true)
-
- for (i <- 0 to 10) {
- try {
- msgTopic.publish(message)
- } catch {
- case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
- // wait for Spark streaming to consume something from the message queue
- Thread.sleep(50)
- }
- }
- }
- } finally {
- client.disconnect()
- client.close()
- client = null
- }
- }
-
- /**
- * Block until at least one receiver has started or timeout occurs.
- */
- private def waitForReceiverToStart() = {
- val latch = new CountDownLatch(1)
- ssc.addStreamingListener(new StreamingListener {
- override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
- latch.countDown()
- }
- })
-
- assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
- }
}
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
new file mode 100644
index 0000000000000..1a371b7008824
--- /dev/null
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.mqtt
+
+import java.net.{ServerSocket, URI}
+
+import scala.language.postfixOps
+
+import com.google.common.base.Charsets.UTF_8
+import org.apache.activemq.broker.{BrokerService, TransportConnector}
+import org.apache.commons.lang3.RandomUtils
+import org.eclipse.paho.client.mqttv3._
+import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkConf}
+
+/**
 + * Shared code for Scala and Python unit tests
+ */
+private class MQTTTestUtils extends Logging {
+
+ private val persistenceDir = Utils.createTempDir()
+ private val brokerHost = "localhost"
+ private val brokerPort = findFreePort()
+
+ private var broker: BrokerService = _
+ private var connector: TransportConnector = _
+
+ def brokerUri: String = {
+ s"$brokerHost:$brokerPort"
+ }
+
+ def setup(): Unit = {
+ broker = new BrokerService()
+ broker.setDataDirectoryFile(Utils.createTempDir())
+ connector = new TransportConnector()
+ connector.setName("mqtt")
+ connector.setUri(new URI("mqtt://" + brokerUri))
+ broker.addConnector(connector)
+ broker.start()
+ }
+
+ def teardown(): Unit = {
+ if (broker != null) {
+ broker.stop()
+ broker = null
+ }
+ if (connector != null) {
+ connector.stop()
+ connector = null
+ }
+ Utils.deleteRecursively(persistenceDir)
+ }
+
+ private def findFreePort(): Int = {
+ val candidatePort = RandomUtils.nextInt(1024, 65536)
+ Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
+ val socket = new ServerSocket(trialPort)
+ socket.close()
+ (null, trialPort)
+ }, new SparkConf())._2
+ }
+
+ def publishData(topic: String, data: String): Unit = {
+ var client: MqttClient = null
+ try {
+ val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
+ client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence)
+ client.connect()
+ if (client.isConnected) {
+ val msgTopic = client.getTopic(topic)
+ val message = new MqttMessage(data.getBytes(UTF_8))
+ message.setQos(1)
+ message.setRetained(true)
+
+ for (i <- 0 to 10) {
+ try {
+ msgTopic.publish(message)
+ } catch {
+ case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
+ // wait for Spark streaming to consume something from the message queue
+ Thread.sleep(50)
+ }
+ }
+ }
+ } finally {
+ if (client != null) {
+ client.disconnect()
+ client.close()
+ client = null
+ }
+ }
+ }
+
+}
diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml
new file mode 100644
index 0000000000000..70d2c9c58f54e
--- /dev/null
+++ b/extras/kinesis-asl-assembly/pom.xml
@@ -0,0 +1,103 @@
+
+
+
+
+ 4.0.0
+
+ org.apache.spark
+ spark-parent_2.10
+ 1.5.0-SNAPSHOT
+ ../../pom.xml
+
+
+ org.apache.spark
+ spark-streaming-kinesis-asl-assembly_2.10
+ jar
+ Spark Project Kinesis Assembly
+ http://spark.apache.org/
+
+
+ streaming-kinesis-asl-assembly
+
+
+
+
+ org.apache.spark
+ spark-streaming-kinesis-asl_${scala.binary.version}
+ ${project.version}
+
+
+ org.apache.spark
+ spark-streaming_${scala.binary.version}
+ ${project.version}
+ provided
+
+
+
+
+ target/scala-${scala.binary.version}/classes
+ target/scala-${scala.binary.version}/test-classes
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+
+ false
+ ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kinesis-asl-assembly-${project.version}.jar
+
+
+ *:*
+
+
+
+
+ *:*
+
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
+
+
+
+
+
+
+ package
+
+ shade
+
+
+
+
+
+ reference.conf
+
+
+ log4j.properties
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml
index c242e7a57b9ab..521b53e230c4a 100644
--- a/extras/kinesis-asl/pom.xml
+++ b/extras/kinesis-asl/pom.xml
@@ -31,7 +31,7 @@
Spark Kinesis Integration
- kinesis-asl
+ streaming-kinesis-asl
diff --git a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py
new file mode 100644
index 0000000000000..f428f64da3c42
--- /dev/null
+++ b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py
@@ -0,0 +1,81 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
 + Consumes messages from an Amazon Kinesis stream and does wordcount.
+
+ This example spins up 1 Kinesis Receiver per shard for the given stream.
+ It then starts pulling from the last checkpointed sequence number of the given stream.
+
 + Usage: kinesis_wordcount_asl.py <app-name> <stream-name> <endpoint-url> <region-name>
 + <app-name> is the name of the consumer app, used to track the read data in DynamoDB
 + <stream-name> name of the Kinesis stream (ie. mySparkStream)
 + <endpoint-url> endpoint of the Kinesis service
 + (e.g. https://kinesis.us-east-1.amazonaws.com)
+
+
+ Example:
+ # export AWS keys if necessary
 + $ export AWS_ACCESS_KEY_ID=<your-access-key-id>
 + $ export AWS_SECRET_KEY=<your-secret-key>
+
+ # run the example
 + $ bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\
+ spark-streaming-kinesis-asl-assembly_*.jar \
+ extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \
+ myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com
+
+ There is a companion helper class called KinesisWordProducerASL which puts dummy data
+ onto the Kinesis stream.
+
+ This code uses the DefaultAWSCredentialsProviderChain to find credentials
+ in the following order:
+ Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
+ Java System Properties - aws.accessKeyId and aws.secretKey
+ Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
+ Instance profile credentials - delivered through the Amazon EC2 metadata service
+ For more information, see
+ http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html
+
+ See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on
+ the Kinesis Spark Streaming integration.
+"""
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream
+
+if __name__ == "__main__":
+ if len(sys.argv) != 5:
+ print(
+ "Usage: kinesis_wordcount_asl.py ",
+ file=sys.stderr)
+ sys.exit(-1)
+
+ sc = SparkContext(appName="PythonStreamingKinesisWordCountAsl")
+ ssc = StreamingContext(sc, 1)
+ appName, streamName, endpointUrl, regionName = sys.argv[1:]
+ lines = KinesisUtils.createStream(
+ ssc, appName, streamName, endpointUrl, regionName, InitialPositionInStream.LATEST, 2)
+ counts = lines.flatMap(lambda line: line.split(" ")) \
+ .map(lambda word: (word, 1)) \
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
index 8f144a4d974a8..a003ddf325e6e 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
@@ -37,16 +37,18 @@ case class SequenceNumberRange(
/** Class representing an array of Kinesis sequence number ranges */
private[kinesis]
-case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) {
+case class SequenceNumberRanges(ranges: Seq[SequenceNumberRange]) {
def isEmpty(): Boolean = ranges.isEmpty
+
def nonEmpty(): Boolean = ranges.nonEmpty
+
override def toString(): String = ranges.mkString("SequenceNumberRanges(", ", ", ")")
}
private[kinesis]
object SequenceNumberRanges {
def apply(range: SequenceNumberRange): SequenceNumberRanges = {
- new SequenceNumberRanges(Array(range))
+ new SequenceNumberRanges(Seq(range))
}
}
@@ -66,14 +68,14 @@ class KinesisBackedBlockRDDPartition(
*/
private[kinesis]
class KinesisBackedBlockRDD(
- sc: SparkContext,
- regionId: String,
- endpointUrl: String,
+ @transient sc: SparkContext,
+ val regionName: String,
+ val endpointUrl: String,
@transient blockIds: Array[BlockId],
- @transient arrayOfseqNumberRanges: Array[SequenceNumberRanges],
+ @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges],
@transient isBlockIdValid: Array[Boolean] = Array.empty,
- retryTimeoutMs: Int = 10000,
- awsCredentialsOption: Option[SerializableAWSCredentials] = None
+ val retryTimeoutMs: Int = 10000,
+ val awsCredentialsOption: Option[SerializableAWSCredentials] = None
) extends BlockRDD[Array[Byte]](sc, blockIds) {
require(blockIds.length == arrayOfseqNumberRanges.length,
@@ -104,7 +106,7 @@ class KinesisBackedBlockRDD(
}
partition.seqNumberRanges.ranges.iterator.flatMap { range =>
new KinesisSequenceRangeIterator(
- credenentials, endpointUrl, regionId, range, retryTimeoutMs)
+ credenentials, endpointUrl, regionName, range, retryTimeoutMs)
}
}
if (partition.isBlockIdValid) {
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
new file mode 100644
index 0000000000000..2e4204dcb6f1a
--- /dev/null
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.streaming.dstream.ReceiverInputDStream
+import org.apache.spark.streaming.receiver.Receiver
+import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
+import org.apache.spark.streaming.{Duration, StreamingContext, Time}
+
+private[kinesis] class KinesisInputDStream(
+ @transient _ssc: StreamingContext,
+ streamName: String,
+ endpointUrl: String,
+ regionName: String,
+ initialPositionInStream: InitialPositionInStream,
+ checkpointAppName: String,
+ checkpointInterval: Duration,
+ storageLevel: StorageLevel,
+ awsCredentialsOption: Option[SerializableAWSCredentials]
+ ) extends ReceiverInputDStream[Array[Byte]](_ssc) {
+
+ private[streaming]
+ override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[Array[Byte]] = {
+
+ // This returns true even when blockInfos is empty
+ val allBlocksHaveRanges = blockInfos.map { _.metadataOption }.forall(_.nonEmpty)
+
+ if (allBlocksHaveRanges) {
+ // Create a KinesisBackedBlockRDD, even when there are no blocks
+ val blockIds = blockInfos.map { _.blockId.asInstanceOf[BlockId] }.toArray
+ val seqNumRanges = blockInfos.map {
+ _.metadataOption.get.asInstanceOf[SequenceNumberRanges] }.toArray
+ val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray
+ logDebug(s"Creating KinesisBackedBlockRDD for $time with ${seqNumRanges.length} " +
+ s"seq number ranges: ${seqNumRanges.mkString(", ")} ")
+ new KinesisBackedBlockRDD(
+ context.sc, regionName, endpointUrl, blockIds, seqNumRanges,
+ isBlockIdValid = isBlockIdValid,
+ retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
+ awsCredentialsOption = awsCredentialsOption)
+ } else {
+ logWarning("Kinesis sequence number information was not present with some block metadata," +
+ " it may not be possible to recover from failures")
+ super.createBlockRDD(time, blockInfos)
+ }
+ }
+
+ override def getReceiver(): Receiver[Array[Byte]] = {
+ new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream,
+ checkpointAppName, checkpointInterval, storageLevel, awsCredentialsOption)
+ }
+}
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 1a8a4cecc1141..22324e821ce94 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -18,17 +18,20 @@ package org.apache.spark.streaming.kinesis
import java.util.UUID
+import scala.collection.JavaConversions.asScalaIterator
+import scala.collection.mutable
import scala.util.control.NonFatal
-import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicAWSCredentials, DefaultAWSCredentialsProviderChain}
+import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain}
import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory}
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker}
+import com.amazonaws.services.kinesis.model.Record
-import org.apache.spark.Logging
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.Duration
-import org.apache.spark.streaming.receiver.Receiver
+import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver}
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkEnv}
private[kinesis]
@@ -42,38 +45,47 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
* Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver.
* This implementation relies on the Kinesis Client Library (KCL) Worker as described here:
* https://github.com/awslabs/amazon-kinesis-client
- * This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here:
- * http://spark.apache.org/docs/latest/streaming-custom-receivers.html
- * Instances of this class will get shipped to the Spark Streaming Workers to run within a
- * Spark Executor.
*
- * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams
- * by the Kinesis Client Library. If you change the App name or Stream name,
- * the KCL will throw errors. This usually requires deleting the backing
- * DynamoDB table with the same name this Kinesis application.
+ * The way this Receiver works is as follows:
+ * - The receiver starts a KCL Worker, which essentially runs a threadpool of multiple
+ * KinesisRecordProcessor
+ * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is
+ * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded.
+ * - When the block generator defines a block, then the recorded sequence number ranges that were
+ * inserted into the block are recorded separately for being used later.
+ * - When the block is ready to be pushed, the block is pushed and the ranges are reported as
+ * metadata of the block. In addition, the ranges are used to find out the latest sequence
+ * number for each shard that can be checkpointed through the DynamoDB.
+ * - Periodically, each KinesisRecordProcessor checkpoints the latest successfully stored sequence
+ * number for its own shard.
+ *
* @param streamName Kinesis stream name
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
* @param regionName Region name used by the Kinesis Client Library for
* DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
* @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
* worker's initial starting position in the stream.
* The values are either the beginning of the stream
* per Kinesis' limit of 24 hours
* (InitialPositionInStream.TRIM_HORIZON) or
* the tip of the stream (InitialPositionInStream.LATEST).
+ * @param checkpointAppName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams
+ * by the Kinesis Client Library. If you change the App name or Stream name,
+ * the KCL will throw errors. This usually requires deleting the backing
+ * DynamoDB table with the same name as this Kinesis application.
+ * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
+ * See the Kinesis Spark Streaming documentation for more
+ * details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects
* @param awsCredentialsOption Optional AWS credentials, used when user directly specifies
* the credentials
*/
private[kinesis] class KinesisReceiver(
- appName: String,
- streamName: String,
+ val streamName: String,
endpointUrl: String,
regionName: String,
initialPositionInStream: InitialPositionInStream,
+ checkpointAppName: String,
checkpointInterval: Duration,
storageLevel: StorageLevel,
awsCredentialsOption: Option[SerializableAWSCredentials]
@@ -90,7 +102,7 @@ private[kinesis] class KinesisReceiver(
* workerId is used by the KCL should be based on the ip address of the actual Spark Worker
* where this code runs (not the driver's IP address.)
*/
- private var workerId: String = null
+ @volatile private var workerId: String = null
/**
* Worker is the core client abstraction from the Kinesis Client Library (KCL).
@@ -98,22 +110,40 @@ private[kinesis] class KinesisReceiver(
* Each shard is assigned its own IRecordProcessor and the worker run multiple such
* processors.
*/
- private var worker: Worker = null
+ @volatile private var worker: Worker = null
+ @volatile private var workerThread: Thread = null
- /** Thread running the worker */
- private var workerThread: Thread = null
+ /** BlockGenerator used to generates blocks out of Kinesis data */
+ @volatile private var blockGenerator: BlockGenerator = null
+ /**
+ * Sequence number ranges added to the current block being generated.
+ * Accessing and updating of this map is synchronized by locks in BlockGenerator.
+ */
+ private val seqNumRangesInCurrentBlock = new mutable.ArrayBuffer[SequenceNumberRange]
+
+ /** Sequence number ranges of data added to each generated block */
+ private val blockIdToSeqNumRanges = new mutable.HashMap[StreamBlockId, SequenceNumberRanges]
+ with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges]
+
+ /**
+ * Latest sequence number ranges that have been stored successfully.
+ * This is used for checkpointing through KCL */
+ private val shardIdToLatestStoredSeqNum = new mutable.HashMap[String, String]
+ with mutable.SynchronizedMap[String, String]
/**
* This is called when the KinesisReceiver starts and must be non-blocking.
* The KCL creates and manages the receiving/processing thread pool through Worker.run().
*/
override def onStart() {
+ blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler)
+
workerId = Utils.localHostName() + ":" + UUID.randomUUID()
// KCL config instance
val awsCredProvider = resolveAWSCredentialsProvider()
val kinesisClientLibConfiguration =
- new KinesisClientLibConfiguration(appName, streamName, awsCredProvider, workerId)
+ new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId)
.withKinesisEndpoint(endpointUrl)
.withInitialPositionInStream(initialPositionInStream)
.withTaskBackoffTimeMillis(500)
@@ -141,6 +171,10 @@ private[kinesis] class KinesisReceiver(
}
}
}
+
+ blockIdToSeqNumRanges.clear()
+ blockGenerator.start()
+
workerThread.setName(s"Kinesis Receiver ${streamId}")
workerThread.setDaemon(true)
workerThread.start()
@@ -165,6 +199,81 @@ private[kinesis] class KinesisReceiver(
workerId = null
}
+ /** Add records of the given shard to the current block being generated */
+ private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = {
+ if (records.size > 0) {
+ val dataIterator = records.iterator().map { record =>
+ val byteBuffer = record.getData()
+ val byteArray = new Array[Byte](byteBuffer.remaining())
+ byteBuffer.get(byteArray)
+ byteArray
+ }
+ val metadata = SequenceNumberRange(streamName, shardId,
+ records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber())
+ blockGenerator.addMultipleDataWithCallback(dataIterator, metadata)
+
+ }
+ }
+
+ /** Get the latest sequence number for the given shard that can be checkpointed through KCL */
+ private[kinesis] def getLatestSeqNumToCheckpoint(shardId: String): Option[String] = {
+ shardIdToLatestStoredSeqNum.get(shardId)
+ }
+
+ /**
+ * Remember the range of sequence numbers that was added to the currently active block.
+ * Internally, this is synchronized with `finalizeRangesForCurrentBlock()`.
+ */
+ private def rememberAddedRange(range: SequenceNumberRange): Unit = {
+ seqNumRangesInCurrentBlock += range
+ }
+
+ /**
+ * Finalize the ranges added to the block that was active and prepare the ranges buffer
+ * for next block. Internally, this is synchronized with `rememberAddedRange()`.
+ */
+ private def finalizeRangesForCurrentBlock(blockId: StreamBlockId): Unit = {
+ blockIdToSeqNumRanges(blockId) = SequenceNumberRanges(seqNumRangesInCurrentBlock.toArray)
+ seqNumRangesInCurrentBlock.clear()
+ logDebug(s"Generated block $blockId has $blockIdToSeqNumRanges")
+ }
+
+ /** Store the block along with its associated ranges */
+ private def storeBlockWithRanges(
+ blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[Array[Byte]]): Unit = {
+ val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId)
+ if (rangesToReportOption.isEmpty) {
+ stop("Error while storing block into Spark, could not find sequence number ranges " +
+ s"for block $blockId")
+ return
+ }
+
+ val rangesToReport = rangesToReportOption.get
+ var attempt = 0
+ var stored = false
+ var throwable: Throwable = null
+ while (!stored && attempt <= 3) {
+ try {
+ store(arrayBuffer, rangesToReport)
+ stored = true
+ } catch {
+ case NonFatal(th) =>
+ attempt += 1
+ throwable = th
+ }
+ }
+ if (!stored) {
+ stop("Error while storing block into Spark", throwable)
+ }
+
+ // Update the latest sequence number that have been successfully stored for each shard
+ // Note that we are doing this sequentially because the array of sequence number ranges
+ // is assumed to be in the order the data was added, so the last range seen for each
+ // shard carries the latest stored sequence number.
+ rangesToReport.ranges.foreach { range =>
+ shardIdToLatestStoredSeqNum(range.shardId) = range.toSeqNumber
+ }
+ }
+
/**
* If AWS credential is provided, return a AWSCredentialProvider returning that credential.
* Otherwise, return the DefaultAWSCredentialsProviderChain.
@@ -182,4 +291,46 @@ private[kinesis] class KinesisReceiver(
new DefaultAWSCredentialsProviderChain()
}
}
+
+
+ /**
+ * Class to handle blocks generated by this receiver's block generator. Specifically, in
+ * the context of the Kinesis Receiver, this handler does the following.
+ *
+ * - When an array of records is added to the current active block in the block generator,
+ * this handler keeps track of the corresponding sequence number range.
+ * - When the currently active block is ready to be sealed (no more records), this handler
+ * keeps track of the list of ranges added into this block in a separate map.
+ */
+ private class GeneratedBlockHandler extends BlockGeneratorListener {
+
+ /**
+ * Callback method called after a data item is added into the BlockGenerator.
+ * The data addition, block generation, and calls to onAddData and onGenerateBlock
+ * are all synchronized through the same lock.
+ */
+ def onAddData(data: Any, metadata: Any): Unit = {
+ rememberAddedRange(metadata.asInstanceOf[SequenceNumberRange])
+ }
+
+ /**
+ * Callback method called after a block has been generated.
+ * The data addition, block generation, and calls to onAddData and onGenerateBlock
+ * are all synchronized through the same lock.
+ */
+ def onGenerateBlock(blockId: StreamBlockId): Unit = {
+ finalizeRangesForCurrentBlock(blockId)
+ }
+
+ /** Callback method called when a block is ready to be pushed / stored. */
+ def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
+ storeBlockWithRanges(blockId,
+ arrayBuffer.asInstanceOf[mutable.ArrayBuffer[Array[Byte]]])
+ }
+
+ /** Callback called in case of any error in the internals of the BlockGenerator */
+ def onError(message: String, throwable: Throwable): Unit = {
+ reportError(message, throwable)
+ }
+ }
}
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index fe9e3a0c793e2..b2405123321e3 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -18,20 +18,16 @@ package org.apache.spark.streaming.kinesis
import java.util.List
-import scala.collection.JavaConversions.asScalaBuffer
import scala.util.Random
+import scala.util.control.NonFatal
-import org.apache.spark.Logging
-
-import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException
-import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException
-import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException
-import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException
-import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor
-import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
+import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException}
+import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer}
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
import com.amazonaws.services.kinesis.model.Record
+import org.apache.spark.Logging
+
/**
* Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor.
* This implementation operates on the Array[Byte] from the KinesisReceiver.
@@ -51,6 +47,7 @@ private[kinesis] class KinesisRecordProcessor(
checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging {
// shardId to be populated during initialize()
+ @volatile
private var shardId: String = _
/**
@@ -75,47 +72,38 @@ private[kinesis] class KinesisRecordProcessor(
override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) {
if (!receiver.isStopped()) {
try {
- /*
- * Notes:
- * 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming
- * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the
- * internally-configured Spark serializer (kryo, etc).
- * 2) This is not desirable, so we instead store a raw Array[Byte] and decouple
- * ourselves from Spark's internal serialization strategy.
- * 3) For performance, the BlockGenerator is asynchronously queuing elements within its
- * memory before creating blocks. This prevents the small block scenario, but requires
- * that you register callbacks to know when a block has been generated and stored
- * (WAL is sufficient for storage) before can checkpoint back to the source.
- */
- batch.foreach(record => receiver.store(record.getData().array()))
-
- logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId")
+ receiver.addRecords(shardId, batch)
+ logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId")
/*
- * Checkpoint the sequence number of the last record successfully processed/stored
- * in the batch.
- * In this implementation, we're checkpointing after the given checkpointIntervalMillis.
- * Note that this logic requires that processRecords() be called AND that it's time to
- * checkpoint. I point this out because there is no background thread running the
- * checkpointer. Checkpointing is tested and trigger only when a new batch comes in.
- * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below).
- * However, if the worker dies unexpectedly, a checkpoint may not happen.
- * This could lead to records being processed more than once.
+ *
+ * Checkpoint the sequence number of the last record successfully stored.
+ * Note that in this current implementation, the checkpointing occurs only after
+ * checkpointIntervalMillis from the last checkpoint, AND when there are new records
+ * to process. This leads to the checkpointing lagging behind the records that have been
+ * stored by the receiver. Of course, this can lead to records being processed more than once,
+ * under failures and restarts.
+ *
+ * TODO: Instead of checkpointing here, run a separate timer task to perform
+ * checkpointing so that it checkpoints in a timely manner independent of whether
+ * new records are available or not.
*/
if (checkpointState.shouldCheckpoint()) {
- /* Perform the checkpoint */
- KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)
+ receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum =>
+ /* Perform the checkpoint */
+ KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100)
- /* Update the next checkpoint time */
- checkpointState.advanceCheckpoint()
+ /* Update the next checkpoint time */
+ checkpointState.advanceCheckpoint()
- logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" +
+ logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" +
s" records for shardId $shardId")
- logDebug(s"Checkpoint: Next checkpoint is at " +
+ logDebug(s"Checkpoint: Next checkpoint is at " +
s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId")
+ }
}
} catch {
- case e: Throwable => {
+ case NonFatal(e) => {
/*
* If there is a failure within the batch, the batch will not be checkpointed.
* This will potentially cause records since the last checkpoint to be processed
@@ -130,7 +118,7 @@ private[kinesis] class KinesisRecordProcessor(
}
} else {
/* RecordProcessor has been stopped. */
- logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" +
+ logInfo(s"Stopped: KinesisReceiver has stopped for workerId $workerId" +
s" and shardId $shardId. No more records will be processed.")
}
}
@@ -154,7 +142,11 @@ private[kinesis] class KinesisRecordProcessor(
* It's now OK to read from the new shards that resulted from a resharding event.
*/
case ShutdownReason.TERMINATE =>
- KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)
+ val latestSeqNumToCheckpointOption = receiver.getLatestSeqNumToCheckpoint(shardId)
+ if (latestSeqNumToCheckpointOption.nonEmpty) {
+ KinesisRecordProcessor.retryRandom(
+ checkpointer.checkpoint(latestSeqNumToCheckpointOption.get), 4, 100)
+ }
/*
* ZOMBIE Use Case. NoOp.
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
index 0ff1b7ed0fd90..711aade182945 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
@@ -36,16 +36,10 @@ import org.apache.spark.Logging
/**
* Shared utility methods for performing Kinesis tests that actually transfer data
*/
-private class KinesisTestUtils(
- val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com",
- _regionName: String = "") extends Logging {
-
- val regionName = if (_regionName.length == 0) {
- RegionUtils.getRegionByEndpoint(endpointUrl).getName()
- } else {
- RegionUtils.getRegion(_regionName).getName()
- }
+private class KinesisTestUtils extends Logging {
+ val endpointUrl = KinesisTestUtils.endpointUrl
+ val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
val streamShardCount = 2
private val createStreamTimeoutSeconds = 300
@@ -53,6 +47,8 @@ private class KinesisTestUtils(
@volatile
private var streamCreated = false
+
+ @volatile
private var _streamName: String = _
private lazy val kinesisClient = {
@@ -73,11 +69,11 @@ private class KinesisTestUtils(
}
def createStream(): Unit = {
- logInfo("Creating stream")
require(!streamCreated, "Stream already created")
_streamName = findNonExistentStreamName()
// Create a stream. The number of shards determines the provisioned throughput.
+ logInfo(s"Creating stream ${_streamName}")
val createStreamRequest = new CreateStreamRequest()
createStreamRequest.setStreamName(_streamName)
createStreamRequest.setShardCount(2)
@@ -86,7 +82,7 @@ private class KinesisTestUtils(
// The stream is now being created. Wait for it to become active.
waitForStreamToBeActive(_streamName)
streamCreated = true
- logInfo("Created stream")
+ logInfo(s"Created stream ${_streamName}")
}
/**
@@ -115,21 +111,16 @@ private class KinesisTestUtils(
shardIdToSeqNumbers.toMap
}
- def describeStream(streamNameToDescribe: String = streamName): Option[StreamDescription] = {
- try {
- val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe)
- val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription()
- Some(desc)
- } catch {
- case rnfe: ResourceNotFoundException =>
- None
- }
+ /**
+ * Expose a Python friendly API.
+ */
+ def pushData(testData: java.util.List[Int]): Unit = {
+ pushData(scala.collection.JavaConversions.asScalaBuffer(testData))
}
def deleteStream(): Unit = {
try {
- if (describeStream().nonEmpty) {
- val deleteStreamRequest = new DeleteStreamRequest()
+ if (streamCreated) {
kinesisClient.deleteStream(streamName)
}
} catch {
@@ -149,6 +140,17 @@ private class KinesisTestUtils(
}
}
+ private def describeStream(streamNameToDescribe: String): Option[StreamDescription] = {
+ try {
+ val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe)
+ val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription()
+ Some(desc)
+ } catch {
+ case rnfe: ResourceNotFoundException =>
+ None
+ }
+ }
+
private def findNonExistentStreamName(): String = {
var testStreamName: String = null
do {
@@ -177,9 +179,38 @@ private class KinesisTestUtils(
private[kinesis] object KinesisTestUtils {
- val envVarName = "ENABLE_KINESIS_TESTS"
+ val envVarNameForEnablingTests = "ENABLE_KINESIS_TESTS"
+ val endVarNameForEndpoint = "KINESIS_TEST_ENDPOINT_URL"
+ val defaultEndpointUrl = "https://kinesis.us-west-2.amazonaws.com"
+
+ lazy val shouldRunTests = {
+ val isEnvSet = sys.env.get(envVarNameForEnablingTests) == Some("1")
+ if (isEnvSet) {
+ // scalastyle:off println
+ // Print this so that they are easily visible on the console and not hidden in the log4j logs.
+ println(
+ s"""
+ |Kinesis tests that actually send data has been enabled by setting the environment
+ |variable $envVarNameForEnablingTests to 1. This will create Kinesis Streams and
+ |DynamoDB tables in AWS. Please be aware that this may incur some AWS costs.
+ |By default, the tests use the endpoint URL $defaultEndpointUrl to create Kinesis streams.
+ |To change this endpoint URL to a different region, you can set the environment variable
+ |$endVarNameForEndpoint to the desired endpoint URL
+ |(e.g. $endVarNameForEndpoint="https://kinesis.us-west-2.amazonaws.com").
+ """.stripMargin)
+ // scalastyle:on println
+ }
+ isEnvSet
+ }
- val shouldRunTests = sys.env.get(envVarName) == Some("1")
+ lazy val endpointUrl = {
+ val url = sys.env.getOrElse(endVarNameForEndpoint, defaultEndpointUrl)
+ // scalastyle:off println
+ // Print this so that they are easily visible on the console and not hidden in the log4j logs.
+ println(s"Using endpoint URL $url for creating Kinesis streams for tests.")
+ // scalastyle:on println
+ url
+ }
def isAWSCredentialsPresent: Boolean = {
Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess
@@ -191,7 +222,13 @@ private[kinesis] object KinesisTestUtils {
Try { new DefaultAWSCredentialsProviderChain().getCredentials() } match {
case Success(cred) => cred
case Failure(e) =>
- throw new Exception("Kinesis tests enabled, but could get not AWS credentials")
+ throw new Exception(
+ s"""
+ |Kinesis tests enabled using environment variable $envVarNameForEnablingTests
+ |but could not find AWS credentials. Please follow instructions in AWS documentation
+ |to set the credentials in your system such that the DefaultAWSCredentialsProviderChain
+ |can find the credentials.
+ """.stripMargin)
}
}
}
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
index e5acab50181e1..c799fadf2d5ce 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
@@ -65,9 +65,8 @@ object KinesisUtils {
): ReceiverInputDStream[Array[Byte]] = {
// Setting scope to override receiver stream's scope of "receiver stream"
ssc.withNamedScope("kinesis stream") {
- ssc.receiverStream(
- new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName),
- initialPositionInStream, checkpointInterval, storageLevel, None))
+ new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName),
+ initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, None)
}
}
@@ -86,19 +85,19 @@ object KinesisUtils {
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
* @param regionName Name of region used by the Kinesis Client Library (KCL) to update
* DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
- * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
* @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
* worker's initial starting position in the stream.
* The values are either the beginning of the stream
* per Kinesis' limit of 24 hours
* (InitialPositionInStream.TRIM_HORIZON) or
* the tip of the stream (InitialPositionInStream.LATEST).
+ * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
+ * See the Kinesis Spark Streaming documentation for more
+ * details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects.
* StorageLevel.MEMORY_AND_DISK_2 is recommended.
+ * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
+ * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
*/
def createStream(
ssc: StreamingContext,
@@ -112,10 +111,11 @@ object KinesisUtils {
awsAccessKeyId: String,
awsSecretKey: String
): ReceiverInputDStream[Array[Byte]] = {
- ssc.receiverStream(
- new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName),
- initialPositionInStream, checkpointInterval, storageLevel,
- Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))))
+ ssc.withNamedScope("kinesis stream") {
+ new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName),
+ initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+ Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
+ }
}
/**
@@ -130,7 +130,7 @@ object KinesisUtils {
* - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in
* [[org.apache.spark.SparkConf]].
*
- * @param ssc Java StreamingContext object
+ * @param ssc StreamingContext object
* @param streamName Kinesis stream name
* @param endpointUrl Endpoint url of Kinesis service
* (e.g., https://kinesis.us-east-1.amazonaws.com)
@@ -155,9 +155,10 @@ object KinesisUtils {
initialPositionInStream: InitialPositionInStream,
storageLevel: StorageLevel
): ReceiverInputDStream[Array[Byte]] = {
- ssc.receiverStream(
- new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, getRegionByEndpoint(endpointUrl),
- initialPositionInStream, checkpointInterval, storageLevel, None))
+ ssc.withNamedScope("kinesis stream") {
+ new KinesisInputDStream(ssc, streamName, endpointUrl, getRegionByEndpoint(endpointUrl),
+ initialPositionInStream, ssc.sc.appName, checkpointInterval, storageLevel, None)
+ }
}
/**
@@ -175,15 +176,15 @@ object KinesisUtils {
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
* @param regionName Name of region used by the Kinesis Client Library (KCL) to update
* DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
* @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
* worker's initial starting position in the stream.
* The values are either the beginning of the stream
* per Kinesis' limit of 24 hours
* (InitialPositionInStream.TRIM_HORIZON) or
* the tip of the stream (InitialPositionInStream.LATEST).
+ * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
+ * See the Kinesis Spark Streaming documentation for more
+ * details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects.
* StorageLevel.MEMORY_AND_DISK_2 is recommended.
*/
@@ -206,8 +207,8 @@ object KinesisUtils {
* This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
*
* Note:
- * The given AWS credentials will get saved in DStream checkpoints if checkpointing
- * is enabled. Make sure that your checkpoint directory is secure.
+ * The given AWS credentials will get saved in DStream checkpoints if checkpointing
+ * is enabled. Make sure that your checkpoint directory is secure.
*
* @param jssc Java StreamingContext object
* @param kinesisAppName Kinesis application name used by the Kinesis Client Library
@@ -216,19 +217,19 @@ object KinesisUtils {
* @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
* @param regionName Name of region used by the Kinesis Client Library (KCL) to update
* DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
- * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
- * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
- * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
- * See the Kinesis Spark Streaming documentation for more
- * details on the different types of checkpoints.
* @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
* worker's initial starting position in the stream.
* The values are either the beginning of the stream
* per Kinesis' limit of 24 hours
* (InitialPositionInStream.TRIM_HORIZON) or
* the tip of the stream (InitialPositionInStream.LATEST).
+ * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
+ * See the Kinesis Spark Streaming documentation for more
+ * details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects.
* StorageLevel.MEMORY_AND_DISK_2 is recommended.
+ * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
+ * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
*/
def createStream(
jssc: JavaStreamingContext,
@@ -297,3 +298,49 @@ object KinesisUtils {
}
}
}
+
+/**
+ * This is a helper class that wraps the methods in KinesisUtils into a more Python-friendly class
+ * and function so that they can be easily instantiated and called from Python's KinesisUtils.
+ */
+private class KinesisUtilsPythonHelper {
+
+ def getInitialPositionInStream(initialPositionInStream: Int): InitialPositionInStream = {
+ initialPositionInStream match {
+ case 0 => InitialPositionInStream.LATEST
+ case 1 => InitialPositionInStream.TRIM_HORIZON
+ case _ => throw new IllegalArgumentException(
+ "Illegal InitialPositionInStream. Please use " +
+ "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON")
+ }
+ }
+
+ def createStream(
+ jssc: JavaStreamingContext,
+ kinesisAppName: String,
+ streamName: String,
+ endpointUrl: String,
+ regionName: String,
+ initialPositionInStream: Int,
+ checkpointInterval: Duration,
+ storageLevel: StorageLevel,
+ awsAccessKeyId: String,
+ awsSecretKey: String
+ ): JavaReceiverInputDStream[Array[Byte]] = {
+ if (awsAccessKeyId == null && awsSecretKey != null) {
+ throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null")
+ }
+ if (awsAccessKeyId != null && awsSecretKey == null) {
+ throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null")
+ }
+ if (awsAccessKeyId == null && awsSecretKey == null) {
+ KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
+ getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel)
+ } else {
+ KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
+ getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel,
+ awsAccessKeyId, awsSecretKey)
+ }
+ }
+
+}
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
index e81fb11e5959f..a89e5627e014c 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
@@ -24,8 +24,6 @@ import org.apache.spark.{SparkConf, SparkContext, SparkException}
class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll {
- private val regionId = "us-east-1"
- private val endpointUrl = "https://kinesis.us-east-1.amazonaws.com"
private val testData = 1 to 8
private var testUtils: KinesisTestUtils = null
@@ -42,7 +40,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
override def beforeAll(): Unit = {
runIfTestsEnabled("Prepare KinesisTestUtils") {
- testUtils = new KinesisTestUtils(endpointUrl)
+ testUtils = new KinesisTestUtils()
testUtils.createStream()
shardIdToDataAndSeqNumbers = testUtils.pushData(testData)
@@ -75,21 +73,21 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
testIfEnabled("Basic reading from Kinesis") {
// Verify all data using multiple ranges in a single RDD partition
- val receivedData1 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl,
+ val receivedData1 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
fakeBlockIds(1),
Array(SequenceNumberRanges(allRanges.toArray))
).map { bytes => new String(bytes).toInt }.collect()
assert(receivedData1.toSet === testData.toSet)
// Verify all data using one range in each of the multiple RDD partitions
- val receivedData2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl,
+ val receivedData2 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
fakeBlockIds(allRanges.size),
allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
).map { bytes => new String(bytes).toInt }.collect()
assert(receivedData2.toSet === testData.toSet)
// Verify ordering within each partition
- val receivedData3 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl,
+ val receivedData3 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
fakeBlockIds(allRanges.size),
allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
).map { bytes => new String(bytes).toInt }.collectPartitions()
@@ -211,7 +209,8 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
}, "Incorrect configuration of RDD, unexpected ranges set"
)
- val rdd = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds, ranges)
+ val rdd = new KinesisBackedBlockRDD(
+ sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges)
val collectedData = rdd.map { bytes =>
new String(bytes).toInt
}.collect()
@@ -224,8 +223,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
if (testIsBlockValid) {
require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager")
require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis")
- val rdd2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds.toArray,
- ranges, isBlockIdValid = Array.fill(blockIds.length)(false))
+ val rdd2 = new KinesisBackedBlockRDD(
+ sc, testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges,
+ isBlockIdValid = Array.fill(blockIds.length)(false))
intercept[SparkException] {
rdd2.collect()
}
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
index 8373138785a89..ee428f31d6ce3 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
@@ -31,7 +31,7 @@ trait KinesisFunSuite extends SparkFunSuite {
if (shouldRunTests) {
test(testName)(testBody)
} else {
- ignore(s"$testName [enable by setting env var $envVarName=1]")(testBody)
+ ignore(s"$testName [enable by setting env var $envVarNameForEnablingTests=1]")(testBody)
}
}
@@ -40,7 +40,7 @@ trait KinesisFunSuite extends SparkFunSuite {
if (shouldRunTests) {
body
} else {
- ignore(s"$message [enable by setting env var $envVarName=1]")()
+ ignore(s"$message [enable by setting env var $envVarNameForEnablingTests=1]")()
}
}
}
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 98f2c7c4f1bfb..ceb135e0651aa 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -22,15 +22,14 @@ import scala.collection.JavaConversions.seqAsJavaList
import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException}
import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
-import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
import com.amazonaws.services.kinesis.model.Record
+import org.mockito.Matchers._
import org.mockito.Mockito._
-import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.mock.MockitoSugar
+import org.scalatest.{BeforeAndAfter, Matchers}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext, TestSuiteBase}
+import org.apache.spark.streaming.{Milliseconds, TestSuiteBase}
import org.apache.spark.util.{Clock, ManualClock, Utils}
/**
@@ -44,6 +43,8 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
val endpoint = "endpoint-url"
val workerId = "dummyWorkerId"
val shardId = "dummyShardId"
+ val seqNum = "dummySeqNum"
+ val someSeqNum = Some(seqNum)
val record1 = new Record()
record1.setData(ByteBuffer.wrap("Spark In Action".getBytes()))
@@ -80,16 +81,18 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
test("process records including store and checkpoint") {
when(receiverMock.isStopped()).thenReturn(false)
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
when(checkpointStateMock.shouldCheckpoint()).thenReturn(true)
val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ recordProcessor.initialize(shardId)
recordProcessor.processRecords(batch, checkpointerMock)
verify(receiverMock, times(1)).isStopped()
- verify(receiverMock, times(1)).store(record1.getData().array())
- verify(receiverMock, times(1)).store(record2.getData().array())
+ verify(receiverMock, times(1)).addRecords(shardId, batch)
+ verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId)
verify(checkpointStateMock, times(1)).shouldCheckpoint()
- verify(checkpointerMock, times(1)).checkpoint()
+ verify(checkpointerMock, times(1)).checkpoint(anyString)
verify(checkpointStateMock, times(1)).advanceCheckpoint()
}
@@ -100,19 +103,25 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
recordProcessor.processRecords(batch, checkpointerMock)
verify(receiverMock, times(1)).isStopped()
+ verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record]))
+ verify(checkpointerMock, never).checkpoint(anyString)
}
test("shouldn't checkpoint when exception occurs during store") {
when(receiverMock.isStopped()).thenReturn(false)
- when(receiverMock.store(record1.getData().array())).thenThrow(new RuntimeException())
+ when(
+ receiverMock.addRecords(anyString, anyListOf(classOf[Record]))
+ ).thenThrow(new RuntimeException())
intercept[RuntimeException] {
val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ recordProcessor.initialize(shardId)
recordProcessor.processRecords(batch, checkpointerMock)
}
verify(receiverMock, times(1)).isStopped()
- verify(receiverMock, times(1)).store(record1.getData().array())
+ verify(receiverMock, times(1)).addRecords(shardId, batch)
+ verify(checkpointerMock, never).checkpoint(anyString)
}
test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") {
@@ -158,19 +167,25 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
}
test("shutdown should checkpoint if the reason is TERMINATE") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
- val reason = ShutdownReason.TERMINATE
- recordProcessor.shutdown(checkpointerMock, reason)
+ recordProcessor.initialize(shardId)
+ recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE)
- verify(checkpointerMock, times(1)).checkpoint()
+ verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId)
+ verify(checkpointerMock, times(1)).checkpoint(anyString)
}
test("shutdown should not checkpoint if the reason is something other than TERMINATE") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ recordProcessor.initialize(shardId)
recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE)
recordProcessor.shutdown(checkpointerMock, null)
- verify(checkpointerMock, never()).checkpoint()
+ verify(checkpointerMock, never).checkpoint(anyString)
}
test("retry success on first attempt") {
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index f9c952b9468bb..1177dc758100d 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -22,34 +22,67 @@ import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random
+import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import org.scalatest.Matchers._
import org.scalatest.concurrent.Eventually
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming._
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.streaming.kinesis.KinesisTestUtils._
+import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
+import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
+import org.apache.spark.util.Utils
+import org.apache.spark.{SparkConf, SparkContext}
class KinesisStreamSuite extends KinesisFunSuite
with Eventually with BeforeAndAfter with BeforeAndAfterAll {
- // This is the name that KCL uses to save metadata to DynamoDB
- private val kinesisAppName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}"
+ // This is the name that KCL will use to save metadata to DynamoDB
+ private val appName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}"
+ private val batchDuration = Seconds(1)
- private var ssc: StreamingContext = _
- private var sc: SparkContext = _
+ // Dummy parameters for API testing
+ private val dummyEndpointUrl = defaultEndpointUrl
+ private val dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName()
+ private val dummyAWSAccessKey = "dummyAccessKey"
+ private val dummyAWSSecretKey = "dummySecretKey"
+
+ private var testUtils: KinesisTestUtils = null
+ private var ssc: StreamingContext = null
+ private var sc: SparkContext = null
override def beforeAll(): Unit = {
val conf = new SparkConf()
.setMaster("local[4]")
.setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name
sc = new SparkContext(conf)
+
+ runIfTestsEnabled("Prepare KinesisTestUtils") {
+ testUtils = new KinesisTestUtils()
+ testUtils.createStream()
+ }
}
override def afterAll(): Unit = {
- sc.stop()
- // Delete the Kinesis stream as well as the DynamoDB table generated by
- // Kinesis Client Library when consuming the stream
+ if (ssc != null) {
+ ssc.stop()
+ }
+ if (sc != null) {
+ sc.stop()
+ }
+ if (testUtils != null) {
+ // Delete the Kinesis stream as well as the DynamoDB table generated by
+ // Kinesis Client Library when consuming the stream
+ testUtils.deleteStream()
+ testUtils.deleteDynamoDBTable(appName)
+ }
+ }
+
+ before {
+ ssc = new StreamingContext(sc, batchDuration)
}
after {
@@ -57,21 +90,75 @@ class KinesisStreamSuite extends KinesisFunSuite
ssc.stop(stopSparkContext = false)
ssc = null
}
+ if (testUtils != null) {
+ testUtils.deleteDynamoDBTable(appName)
+ }
}
test("KinesisUtils API") {
- ssc = new StreamingContext(sc, Seconds(1))
// Tests the API, does not actually test data receiving
val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", Seconds(2),
+ dummyEndpointUrl, Seconds(2),
InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2)
val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
+ dummyEndpointUrl, dummyRegionName,
InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2)
val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
+ dummyEndpointUrl, dummyRegionName,
InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2,
- "awsAccessKey", "awsSecretKey")
+ dummyAWSAccessKey, dummyAWSSecretKey)
+ }
+
+ test("RDD generation") {
+ val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream",
+ dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2),
+ StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey)
+ assert(inputStream.isInstanceOf[KinesisInputDStream])
+
+ val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream]
+ val time = Time(1000)
+
+ // Generate block info data for testing
+ val seqNumRanges1 = SequenceNumberRanges(
+ SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))
+ val blockId1 = StreamBlockId(kinesisStream.id, 123)
+ val blockInfo1 = ReceivedBlockInfo(
+ 0, None, Some(seqNumRanges1), new BlockManagerBasedStoreResult(blockId1, None))
+
+ val seqNumRanges2 = SequenceNumberRanges(
+ SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb"))
+ val blockId2 = StreamBlockId(kinesisStream.id, 345)
+ val blockInfo2 = ReceivedBlockInfo(
+ 0, None, Some(seqNumRanges2), new BlockManagerBasedStoreResult(blockId2, None))
+
+ // Verify that the generated KinesisBackedBlockRDD has all the right information
+ val blockInfos = Seq(blockInfo1, blockInfo2)
+ val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos)
+ nonEmptyRDD shouldBe a [KinesisBackedBlockRDD]
+ val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD]
+ assert(kinesisRDD.regionName === dummyRegionName)
+ assert(kinesisRDD.endpointUrl === dummyEndpointUrl)
+ assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds)
+ assert(kinesisRDD.awsCredentialsOption ===
+ Some(SerializableAWSCredentials(dummyAWSAccessKey, dummyAWSSecretKey)))
+ assert(nonEmptyRDD.partitions.size === blockInfos.size)
+ nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] }
+ val partitions = nonEmptyRDD.partitions.map {
+ _.asInstanceOf[KinesisBackedBlockRDDPartition] }.toSeq
+ assert(partitions.map { _.seqNumberRanges } === Seq(seqNumRanges1, seqNumRanges2))
+ assert(partitions.map { _.blockId } === Seq(blockId1, blockId2))
+ assert(partitions.forall { _.isBlockIdValid === true })
+
+ // Verify that KinesisBackedBlockRDD is generated even when there are no blocks
+ val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty)
+ emptyRDD shouldBe a [KinesisBackedBlockRDD]
+ emptyRDD.partitions shouldBe empty
+
+ // Verify that the KinesisBackedBlockRDD has isBlockIdValid = false when blocks are invalid
+ blockInfos.foreach { _.setBlockIdInvalid() }
+ kinesisStream.createBlockRDD(time, blockInfos).partitions.foreach { partition =>
+ assert(partition.asInstanceOf[KinesisBackedBlockRDDPartition].isBlockIdValid === false)
+ }
}
@@ -84,32 +171,91 @@ class KinesisStreamSuite extends KinesisFunSuite
* and you have to set the system environment variable RUN_KINESIS_TESTS=1 .
*/
testIfEnabled("basic operation") {
- val kinesisTestUtils = new KinesisTestUtils()
- try {
- kinesisTestUtils.createStream()
- ssc = new StreamingContext(sc, Seconds(1))
- val aWSCredentials = KinesisTestUtils.getAWSCredentials()
- val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName,
- kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST,
- Seconds(10), StorageLevel.MEMORY_ONLY,
- aWSCredentials.getAWSAccessKeyId, aWSCredentials.getAWSSecretKey)
-
- val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int]
- stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd =>
- collected ++= rdd.collect()
- logInfo("Collected = " + rdd.collect().toSeq.mkString(", "))
- }
- ssc.start()
+ val awsCredentials = KinesisTestUtils.getAWSCredentials()
+ val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
+ testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
+ Seconds(10), StorageLevel.MEMORY_ONLY,
+ awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
- val testData = 1 to 10
- eventually(timeout(120 seconds), interval(10 second)) {
- kinesisTestUtils.pushData(testData)
- assert(collected === testData.toSet, "\nData received does not match data sent")
+ val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int]
+ stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd =>
+ collected ++= rdd.collect()
+ logInfo("Collected = " + rdd.collect().toSeq.mkString(", "))
+ }
+ ssc.start()
+
+ val testData = 1 to 10
+ eventually(timeout(120 seconds), interval(10 second)) {
+ testUtils.pushData(testData)
+ assert(collected === testData.toSet, "\nData received does not match data sent")
+ }
+ ssc.stop(stopSparkContext = false)
+ }
+
+ testIfEnabled("failure recovery") {
+ val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
+ val checkpointDir = Utils.createTempDir().getAbsolutePath
+
+ ssc = new StreamingContext(sc, Milliseconds(1000))
+ ssc.checkpoint(checkpointDir)
+
+ val awsCredentials = KinesisTestUtils.getAWSCredentials()
+ val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])]
+ with mutable.SynchronizedMap[Time, (Array[SequenceNumberRanges], Seq[Int])]
+
+ val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
+ testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
+ Seconds(10), StorageLevel.MEMORY_ONLY,
+ awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+ // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch
+ kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => {
+ val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD]
+ val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq
+ collectedData(time) = (kRdd.arrayOfseqNumberRanges, data)
+ })
+
+ ssc.remember(Minutes(60)) // remember all the batches so that they are all saved in checkpoint
+ ssc.start()
+
+ def numBatchesWithData: Int = collectedData.count(_._2._2.nonEmpty)
+
+ def isCheckpointPresent: Boolean = Checkpoint.getCheckpointFiles(checkpointDir).nonEmpty
+
+ // Run until there are at least 10 batches with some data in them
+ // If this times out because numBatchesWithData is zero, then it's likely that the foreachRDD
+ // function failed with exceptions, and nothing got added to `collectedData`
+ eventually(timeout(2 minutes), interval(1 seconds)) {
+ testUtils.pushData(1 to 5)
+ assert(isCheckpointPresent && numBatchesWithData > 10)
+ }
+ ssc.stop(stopSparkContext = true) // stop the SparkContext so that the blocks are not reused
+
+ // Restart the context from checkpoint and verify whether the recovered streams work correctly
+ logInfo("Restarting from checkpoint")
+ ssc = new StreamingContext(checkpointDir)
+ ssc.start()
+ val recoveredKinesisStream = ssc.graph.getInputStreams().head
+
+ // Verify that the recomputed RDDs are KinesisBackedBlockRDDs with the same sequence ranges
+ // and return the same data
+ val times = collectedData.keySet
+ times.foreach { time =>
+ val (arrayOfSeqNumRanges, data) = collectedData(time)
+ val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]]
+ rdd shouldBe a [KinesisBackedBlockRDD]
+
+ // Verify the recovered sequence ranges
+ val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD]
+ assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size)
+ arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) =>
+ assert(expected.ranges.toSeq === found.ranges.toSeq)
}
- ssc.stop()
- } finally {
- kinesisTestUtils.deleteStream()
- kinesisTestUtils.deleteDynamoDBTable(kinesisAppName)
+
+ // Verify the recovered data
+ assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSeq === data)
}
+ ssc.stop()
}
+
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
index 33ac7b0ed6095..7f4e7e9d79d6b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
@@ -87,7 +87,7 @@ class VertexRDDImpl[VD] private[graphx] (
/** The number of vertices in the RDD. */
override def count(): Long = {
- partitionsRDD.map(_.size).reduce(_ + _)
+ partitionsRDD.map(_.size.toLong).reduce(_ + _)
}
override private[graphx] def mapVertexPartitions[VD2: ClassTag](
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
index de85720febf23..5f95e2c74f902 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
@@ -69,7 +69,8 @@ public List buildCommand(Map env) throws IOException {
} else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) {
javaOptsKeys.add("SPARK_EXECUTOR_OPTS");
memKey = "SPARK_EXECUTOR_MEMORY";
- } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService")) {
+ } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService") ||
+ className.equals("org.apache.spark.deploy.mesos.MesosExternalShuffleService")) {
javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS");
javaOptsKeys.add("SPARK_SHUFFLE_OPTS");
memKey = "SPARK_DAEMON_MEMORY";
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
index c0f89c9230692..03c9358bc865d 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
@@ -20,12 +20,13 @@
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static org.apache.spark.launcher.CommandBuilderUtils.*;
-/**
+/**
* Launcher for Spark applications.
*
* Use this class to start Spark applications programmatically. The class uses a builder pattern
@@ -57,7 +58,8 @@ public class SparkLauncher {
/** Configuration key for the number of executor CPU cores. */
public static final String EXECUTOR_CORES = "spark.executor.cores";
- private final SparkSubmitCommandBuilder builder;
+ // Visible for testing.
+ final SparkSubmitCommandBuilder builder;
public SparkLauncher() {
this(null);
@@ -187,6 +189,73 @@ public SparkLauncher setMainClass(String mainClass) {
return this;
}
+ /**
+ * Adds a no-value argument to the Spark invocation. If the argument is known, this method
+ * validates whether the argument is indeed a no-value argument, and throws an exception
+ * otherwise.
+ *
+ * Use this method with caution. It is possible to create an invalid Spark command by passing
+ * unknown arguments to this method, since those are allowed for forward compatibility.
+ *
+ * @param arg Argument to add.
+ * @return This launcher.
+ */
+ public SparkLauncher addSparkArg(String arg) {
+ SparkSubmitOptionParser validator = new ArgumentValidator(false);
+ validator.parse(Arrays.asList(arg));
+ builder.sparkArgs.add(arg);
+ return this;
+ }
+
+ /**
+ * Adds an argument with a value to the Spark invocation. If the argument name corresponds to
+ * a known argument, the code validates that the argument actually expects a value, and throws
+ * an exception otherwise.
+ *
+ * It is safe to add arguments modified by other methods in this class (such as
+ * {@link #setMaster(String)} - the last invocation will be the one to take effect.
+ *
+ * Use this method with caution. It is possible to create an invalid Spark command by passing
+ * unknown arguments to this method, since those are allowed for forward compatibility.
+ *
+ * @param name Name of argument to add.
+ * @param value Value of the argument.
+ * @return This launcher.
+ */
+ public SparkLauncher addSparkArg(String name, String value) {
+ SparkSubmitOptionParser validator = new ArgumentValidator(true);
+ if (validator.MASTER.equals(name)) {
+ setMaster(value);
+ } else if (validator.PROPERTIES_FILE.equals(name)) {
+ setPropertiesFile(value);
+ } else if (validator.CONF.equals(name)) {
+ String[] vals = value.split("=", 2);
+ setConf(vals[0], vals[1]);
+ } else if (validator.CLASS.equals(name)) {
+ setMainClass(value);
+ } else if (validator.JARS.equals(name)) {
+ builder.jars.clear();
+ for (String jar : value.split(",")) {
+ addJar(jar);
+ }
+ } else if (validator.FILES.equals(name)) {
+ builder.files.clear();
+ for (String file : value.split(",")) {
+ addFile(file);
+ }
+ } else if (validator.PY_FILES.equals(name)) {
+ builder.pyFiles.clear();
+ for (String file : value.split(",")) {
+ addPyFile(file);
+ }
+ } else {
+ validator.parse(Arrays.asList(name, value));
+ builder.sparkArgs.add(name);
+ builder.sparkArgs.add(value);
+ }
+ return this;
+ }
+
/**
* Adds command line arguments for the application.
*
@@ -277,4 +346,32 @@ public Process launch() throws IOException {
return pb.start();
}
+ private static class ArgumentValidator extends SparkSubmitOptionParser {
+
+ private final boolean hasValue;
+
+ ArgumentValidator(boolean hasValue) {
+ this.hasValue = hasValue;
+ }
+
+ @Override
+ protected boolean handle(String opt, String value) {
+ if (value == null && hasValue) {
+ throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt));
+ }
+ return true;
+ }
+
+ @Override
+ protected boolean handleUnknown(String opt) {
+ // Do not fail on unknown arguments, to support future arguments added to SparkSubmit.
+ return true;
+ }
+
+ protected void handleExtraArgs(List extra) {
+ // No op.
+ }
+
+ };
+
}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index 87c43aa9980e1..4f354cedee66f 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -76,7 +76,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
"spark-internal");
}
- private final List sparkArgs;
+ final List sparkArgs;
private final boolean printHelp;
/**
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java
index b88bba883ac65..5779eb3fc0f78 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java
@@ -51,6 +51,7 @@ class SparkSubmitOptionParser {
protected final String MASTER = "--master";
protected final String NAME = "--name";
protected final String PACKAGES = "--packages";
+ protected final String PACKAGES_EXCLUDE = "--exclude-packages";
protected final String PROPERTIES_FILE = "--properties-file";
protected final String PROXY_USER = "--proxy-user";
protected final String PY_FILES = "--py-files";
@@ -105,6 +106,7 @@ class SparkSubmitOptionParser {
{ NAME },
{ NUM_EXECUTORS },
{ PACKAGES },
+ { PACKAGES_EXCLUDE },
{ PRINCIPAL },
{ PROPERTIES_FILE },
{ PROXY_USER },
diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
index 252d5abae1ca3..d0c26dd05679b 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
@@ -20,6 +20,7 @@
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -35,8 +36,54 @@ public class SparkLauncherSuite {
private static final Logger LOG = LoggerFactory.getLogger(SparkLauncherSuite.class);
+ @Test
+ public void testSparkArgumentHandling() throws Exception {
+ SparkLauncher launcher = new SparkLauncher()
+ .setSparkHome(System.getProperty("spark.test.home"));
+ SparkSubmitOptionParser opts = new SparkSubmitOptionParser();
+
+ launcher.addSparkArg(opts.HELP);
+ try {
+ launcher.addSparkArg(opts.PROXY_USER);
+ fail("Expected IllegalArgumentException.");
+ } catch (IllegalArgumentException e) {
+ // Expected.
+ }
+
+ launcher.addSparkArg(opts.PROXY_USER, "someUser");
+ try {
+ launcher.addSparkArg(opts.HELP, "someValue");
+ fail("Expected IllegalArgumentException.");
+ } catch (IllegalArgumentException e) {
+ // Expected.
+ }
+
+ launcher.addSparkArg("--future-argument");
+ launcher.addSparkArg("--future-argument", "someValue");
+
+ launcher.addSparkArg(opts.MASTER, "myMaster");
+ assertEquals("myMaster", launcher.builder.master);
+
+ launcher.addJar("foo");
+ launcher.addSparkArg(opts.JARS, "bar");
+ assertEquals(Arrays.asList("bar"), launcher.builder.jars);
+
+ launcher.addFile("foo");
+ launcher.addSparkArg(opts.FILES, "bar");
+ assertEquals(Arrays.asList("bar"), launcher.builder.files);
+
+ launcher.addPyFile("foo");
+ launcher.addSparkArg(opts.PY_FILES, "bar");
+ assertEquals(Arrays.asList("bar"), launcher.builder.pyFiles);
+
+ launcher.setConf("spark.foo", "foo");
+ launcher.addSparkArg(opts.CONF, "spark.foo=bar");
+ assertEquals("bar", launcher.builder.conf.get("spark.foo"));
+ }
+
@Test
public void testChildProcLauncher() throws Exception {
+ SparkSubmitOptionParser opts = new SparkSubmitOptionParser();
Map env = new HashMap();
env.put("SPARK_PRINT_LAUNCH_COMMAND", "1");
@@ -44,9 +91,12 @@ public void testChildProcLauncher() throws Exception {
.setSparkHome(System.getProperty("spark.test.home"))
.setMaster("local")
.setAppResource("spark-internal")
+ .addSparkArg(opts.CONF,
+ String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS))
.setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS,
"-Dfoo=bar -Dtest.name=-testChildProcLauncher")
.setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path"))
+ .addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow")
.setMainClass(SparkLauncherTestApp.class.getName())
.addAppArgs("proc");
final Process app = launcher.launch();
diff --git a/make-distribution.sh b/make-distribution.sh
index cac7032bb2e87..247a81341e4a4 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)"
DISTDIR="$SPARK_HOME/dist"
SPARK_TACHYON=false
-TACHYON_VERSION="0.6.4"
+TACHYON_VERSION="0.7.0"
TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz"
TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}"
@@ -219,7 +219,6 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR"
if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then
mkdir -p "$DISTDIR"/R/lib
cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib
- cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib
fi
# Download and copy in tachyon, if requested
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index aef2c019d2871..a3e59401c5cfb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -198,6 +198,6 @@ class PipelineModel private[ml] (
}
override def copy(extra: ParamMap): PipelineModel = {
- new PipelineModel(uid, stages.map(_.copy(extra)))
+ new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala
new file mode 100644
index 0000000000000..7429f9d652ac5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
+import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
+
+/**
+ * In-place DGEMM and DGEMV for Breeze
+ */
+private[ann] object BreezeUtil {
+
+ // TODO: switch to MLlib BLAS interface
+ private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N"
+
+ /**
+ * DGEMM: C := alpha * A * B + beta * C
+ * @param alpha alpha
+ * @param a A
+ * @param b B
+ * @param beta beta
+ * @param c C
+ */
+ def dgemm(alpha: Double, a: BDM[Double], b: BDM[Double], beta: Double, c: BDM[Double]): Unit = {
+ // TODO: add code if matrices isTranspose!!!
+ require(a.cols == b.rows, "A & B Dimension mismatch!")
+ require(a.rows == c.rows, "A & C Dimension mismatch!")
+ require(b.cols == c.cols, "B & C Dimension mismatch!")
+ NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols,
+ alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride,
+ beta, c.data, c.offset, c.rows)
+ }
+
+ /**
+ * DGEMV: y := alpha * A * x + beta * y
+ * @param alpha alpha
+ * @param a A
+ * @param x x
+ * @param beta beta
+ * @param y y
+ */
+ def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = {
+ require(a.cols == x.length, "A & b Dimension mismatch!")
+ NativeBLAS.dgemv(transposeString(a), a.rows, a.cols,
+ alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride,
+ beta, y.data, y.offset, y.stride)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
new file mode 100644
index 0000000000000..b5258ff348477
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -0,0 +1,882 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy,
+ sum => Bsum}
+import breeze.numerics.{log => Blog, sigmoid => Bsigmoid}
+
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.optimization._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Trait that holds Layer properties, that are needed to instantiate it.
+ * Implements Layer instantiation.
+ *
+ */
+private[ann] trait Layer extends Serializable {
+ /**
+ * Returns the instance of the layer based on weights provided
+ * @param weights vector with layer weights
+ * @param position position of weights in the vector
+ * @return the layer model
+ */
+ def getInstance(weights: Vector, position: Int): LayerModel
+
+ /**
+ * Returns the instance of the layer with random generated weights
+ * @param seed seed
+ * @return the layer model
+ */
+ def getInstance(seed: Long): LayerModel
+}
+
+/**
+ * Trait that holds Layer weights (or parameters).
+ * Implements functions needed for forward propagation, computing delta and gradient.
+ * Can return weights in Vector format.
+ */
+private[ann] trait LayerModel extends Serializable {
+ /**
+ * number of weights
+ */
+ val size: Int
+
+ /**
+ * Evaluates the data (process the data through the layer)
+ * @param data data
+ * @return processed data
+ */
+ def eval(data: BDM[Double]): BDM[Double]
+
+ /**
+ * Computes the delta for back propagation
+ * @param nextDelta delta of the next layer
+ * @param input input data
+ * @return delta
+ */
+ def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double]
+
+ /**
+ * Computes the gradient
+ * @param delta delta for this layer
+ * @param input input data
+ * @return gradient
+ */
+ def grad(delta: BDM[Double], input: BDM[Double]): Array[Double]
+
+ /**
+ * Returns weights for the layer in a single vector
+ * @return layer weights
+ */
+ def weights(): Vector
+}
+
+/**
+ * Layer properties of affine transformations, that is y=A*x+b
+ * @param numIn number of inputs
+ * @param numOut number of outputs
+ */
+private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer {
+
+ override def getInstance(weights: Vector, position: Int): LayerModel = {
+ AffineLayerModel(this, weights, position)
+ }
+
+ override def getInstance(seed: Long = 11L): LayerModel = {
+ AffineLayerModel(this, seed)
+ }
+}
+
+/**
+ * Model of Affine layer y=A*x+b
+ * @param w weights (matrix A)
+ * @param b bias (vector b)
+ */
+private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel {
+ val size = w.size + b.length
+ val gwb = new Array[Double](size)
+ private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb)
+ private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size)
+ private var z: BDM[Double] = null
+ private var d: BDM[Double] = null
+ private var ones: BDV[Double] = null
+
+ override def eval(data: BDM[Double]): BDM[Double] = {
+ if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols)
+ z(::, *) := b
+ BreezeUtil.dgemm(1.0, w, data, 1.0, z)
+ z
+ }
+
+ override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = {
+ if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols)
+ BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d)
+ d
+ }
+
+ override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = {
+ BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw)
+ if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols)
+ BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb)
+ gwb
+ }
+
+ override def weights(): Vector = AffineLayerModel.roll(w, b)
+}
+
+/**
+ * Factory for Affine layer models
+ */
+private[ann] object AffineLayerModel {
+
+ /**
+ * Creates a model of Affine layer
+ * @param layer layer properties
+ * @param weights vector with weights
+ * @param position position of weights in the vector
+ * @return model of Affine layer
+ */
+ def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = {
+ val (w, b) = unroll(weights, position, layer.numIn, layer.numOut)
+ new AffineLayerModel(w, b)
+ }
+
+ /**
+ * Creates a model of Affine layer
+ * @param layer layer properties
+ * @param seed seed
+ * @return model of Affine layer
+ */
+ def apply(layer: AffineLayer, seed: Long): AffineLayerModel = {
+ val (w, b) = randomWeights(layer.numIn, layer.numOut, seed)
+ new AffineLayerModel(w, b)
+ }
+
+ /**
+ * Unrolls the weights from the vector
+ * @param weights vector with weights
+ * @param position position of weights for this layer
+ * @param numIn number of layer inputs
+ * @param numOut number of layer outputs
+ * @return matrix A and vector b
+ */
+ def unroll(
+ weights: Vector,
+ position: Int,
+ numIn: Int,
+ numOut: Int): (BDM[Double], BDV[Double]) = {
+ val weightsCopy = weights.toArray
+ // TODO: the array is not copied to BDMs, make sure this is OK!
+ val a = new BDM[Double](numOut, numIn, weightsCopy, position)
+ val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut)
+ (a, b)
+ }
+
+ /**
+ * Roll the layer weights into a vector
+ * @param a matrix A
+ * @param b vector b
+ * @return vector of weights
+ */
+ def roll(a: BDM[Double], b: BDV[Double]): Vector = {
+ val result = new Array[Double](a.size + b.length)
+ // TODO: make sure that we need to copy!
+ System.arraycopy(a.toArray, 0, result, 0, a.size)
+ System.arraycopy(b.toArray, 0, result, a.size, b.length)
+ Vectors.dense(result)
+ }
+
+ /**
+ * Generate random weights for the layer
+ * @param numIn number of inputs
+ * @param numOut number of outputs
+ * @param seed seed
+ * @return (matrix A, vector b)
+ */
+ def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = {
+ val rand: XORShiftRandom = new XORShiftRandom(seed)
+ val weights = BDM.fill[Double](numOut, numIn){ (rand.nextDouble * 4.8 - 2.4) / numIn }
+ val bias = BDV.fill[Double](numOut){ (rand.nextDouble * 4.8 - 2.4) / numIn }
+ (weights, bias)
+ }
+}
+
+/**
+ * Trait for functions and their derivatives for functional layers
+ */
+private[ann] trait ActivationFunction extends Serializable {
+
+ /**
+ * Implements a function
+ * @param x input data
+ * @param y output data
+ */
+ def eval(x: BDM[Double], y: BDM[Double]): Unit
+
+ /**
+ * Implements a derivative of a function (needed for the back propagation)
+ * @param x input data
+ * @param y output data
+ */
+ def derivative(x: BDM[Double], y: BDM[Double]): Unit
+
+ /**
+ * Implements a cross entropy error of a function.
+ * Needed if the functional layer that contains this function is the output layer
+ * of the network.
+ * @param target target output
+ * @param output computed output
+ * @param result intermediate result
+ * @return cross-entropy
+ */
+ def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double
+
+ /**
+ * Implements a mean squared error of a function
+ * @param target target output
+ * @param output computed output
+ * @param result intermediate result
+ * @return mean squared error
+ */
+ def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double
+}
+
+/**
+ * Implements in-place application of functions
+ */
+private[ann] object ActivationFunction {
+
+ def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = {
+ var i = 0
+ while (i < x.rows) {
+ var j = 0
+ while (j < x.cols) {
+ y(i, j) = func(x(i, j))
+ j += 1
+ }
+ i += 1
+ }
+ }
+
+ def apply(
+ x1: BDM[Double],
+ x2: BDM[Double],
+ y: BDM[Double],
+ func: (Double, Double) => Double): Unit = {
+ var i = 0
+ while (i < x1.rows) {
+ var j = 0
+ while (j < x1.cols) {
+ y(i, j) = func(x1(i, j), x2(i, j))
+ j += 1
+ }
+ i += 1
+ }
+ }
+}
+
+/**
+ * Implements SoftMax activation function
+ */
+private[ann] class SoftmaxFunction extends ActivationFunction {
+ override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
+ var j = 0
+ // find max value to make sure later that exponent is computable
+ while (j < x.cols) {
+ var i = 0
+ var max = Double.MinValue
+ while (i < x.rows) {
+ if (x(i, j) > max) {
+ max = x(i, j)
+ }
+ i += 1
+ }
+ var sum = 0.0
+ i = 0
+ while (i < x.rows) {
+ val res = Math.exp(x(i, j) - max)
+ y(i, j) = res
+ sum += res
+ i += 1
+ }
+ i = 0
+ while (i < x.rows) {
+ y(i, j) /= sum
+ i += 1
+ }
+ j += 1
+ }
+ }
+
+ override def crossEntropy(
+ output: BDM[Double],
+ target: BDM[Double],
+ result: BDM[Double]): Double = {
+ def m(o: Double, t: Double): Double = o - t
+ ActivationFunction(output, target, result, m)
+ -Bsum( target :* Blog(output)) / output.cols
+ }
+
+ override def derivative(x: BDM[Double], y: BDM[Double]): Unit = {
+ def sd(z: Double): Double = (1 - z) * z
+ ActivationFunction(x, y, sd)
+ }
+
+ override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = {
+ throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.")
+ }
+}
+
+/**
+ * Implements Sigmoid activation function
+ */
+private[ann] class SigmoidFunction extends ActivationFunction {
+ override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
+ def s(z: Double): Double = Bsigmoid(z)
+ ActivationFunction(x, y, s)
+ }
+
+ override def crossEntropy(
+ output: BDM[Double],
+ target: BDM[Double],
+ result: BDM[Double]): Double = {
+ def m(o: Double, t: Double): Double = o - t
+ ActivationFunction(output, target, result, m)
+ -Bsum(target :* Blog(output)) / output.cols
+ }
+
+ override def derivative(x: BDM[Double], y: BDM[Double]): Unit = {
+ def sd(z: Double): Double = (1 - z) * z
+ ActivationFunction(x, y, sd)
+ }
+
+ override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = {
+ // TODO: make it readable
+ def m(o: Double, t: Double): Double = (o - t)
+ ActivationFunction(output, target, result, m)
+ val e = Bsum(result :* result) / 2 / output.cols
+ def m2(x: Double, o: Double) = x * (o - o * o)
+ ActivationFunction(result, output, result, m2)
+ e
+ }
+}
+
+/**
+ * Functional layer properties, y = f(x)
+ * @param activationFunction activation function
+ */
+private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer {
+ override def getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L)
+
+ override def getInstance(seed: Long): LayerModel =
+ FunctionalLayerModel(this)
+}
+
+/**
+ * Functional layer model. Holds no weights.
+ * @param activationFunction activation function
+ */
+private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction)
+ extends LayerModel {
+ val size = 0
+ // matrices for in-place computations
+ // outputs
+ private var f: BDM[Double] = null
+ // delta
+ private var d: BDM[Double] = null
+ // matrix for error computation
+ private var e: BDM[Double] = null
+ // delta gradient
+ private lazy val dg = new Array[Double](0)
+
+ override def eval(data: BDM[Double]): BDM[Double] = {
+ if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols)
+ activationFunction.eval(data, f)
+ f
+ }
+
+ override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = {
+ if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols)
+ activationFunction.derivative(input, d)
+ d :*= nextDelta
+ d
+ }
+
+ override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg
+
+ override def weights(): Vector = Vectors.dense(new Array[Double](0))
+
+ def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
+ if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols)
+ val error = activationFunction.crossEntropy(output, target, e)
+ (e, error)
+ }
+
+ def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
+ if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols)
+ val error = activationFunction.squared(output, target, e)
+ (e, error)
+ }
+
+ def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
+ // TODO: allow user pick error
+ activationFunction match {
+ case sigmoid: SigmoidFunction => squared(output, target)
+ case softmax: SoftmaxFunction => crossEntropy(output, target)
+ }
+ }
+}
+
+/**
+ * Factory of functional layer models
+ */
+private[ann] object FunctionalLayerModel {
+ def apply(layer: FunctionalLayer): FunctionalLayerModel =
+ new FunctionalLayerModel(layer.activationFunction)
+}
+
+/**
+ * Trait for the artificial neural network (ANN) topology properties
+ */
+private[ann] trait Topology extends Serializable{
+ def getInstance(weights: Vector): TopologyModel
+ def getInstance(seed: Long): TopologyModel
+}
+
+/**
+ * Trait for ANN topology model
+ */
+private[ann] trait TopologyModel extends Serializable{
+ /**
+ * Forward propagation
+ * @param data input data
+ * @return array of outputs for each of the layers
+ */
+ def forward(data: BDM[Double]): Array[BDM[Double]]
+
+ /**
+ * Prediction of the model
+ * @param data input data
+ * @return prediction
+ */
+ def predict(data: Vector): Vector
+
+ /**
+ * Computes gradient for the network
+ * @param data input data
+ * @param target target output
+ * @param cumGradient cumulative gradient
+ * @param blockSize block size
+ * @return error
+ */
+ def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector,
+ blockSize: Int): Double
+
+ /**
+ * Returns the weights of the ANN
+ * @return weights
+ */
+ def weights(): Vector
+}
+
+/**
+ * Feed forward ANN
+ * @param layers array of layers that form the network
+ */
+private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology {
+ override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights)
+
+ override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed)
+}
+
+/**
+ * Factory for some of the frequently-used topologies
+ */
+private[ml] object FeedForwardTopology {
+ /**
+ * Creates a feed forward topology from the array of layers
+ * @param layers array of layers
+ * @return feed forward topology
+ */
+ def apply(layers: Array[Layer]): FeedForwardTopology = {
+ new FeedForwardTopology(layers)
+ }
+
+ /**
+ * Creates a multi-layer perceptron
+ * @param layerSizes sizes of layers including input and output size
+ * @param softmax whether to use SoftMax or Sigmoid function for an output layer.
+ * Softmax is default
+ * @return multilayer perceptron topology
+ */
+ def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = {
+ val layers = new Array[Layer]((layerSizes.length - 1) * 2)
+ for(i <- 0 until layerSizes.length - 1){
+ layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1))
+ layers(i * 2 + 1) =
+ if (softmax && i == layerSizes.length - 2) {
+ new FunctionalLayer(new SoftmaxFunction())
+ } else {
+ new FunctionalLayer(new SigmoidFunction())
+ }
+ }
+ FeedForwardTopology(layers)
+ }
+}
+
+/**
+ * Model of Feed Forward Neural Network.
+ * Implements forward, gradient computation and can return weights in vector format.
+ * @param layerModels models of layers
+ * @param topology topology of the network
+ */
+private[ml] class FeedForwardModel private(
+ val layerModels: Array[LayerModel],
+ val topology: FeedForwardTopology) extends TopologyModel {
+ override def forward(data: BDM[Double]): Array[BDM[Double]] = {
+ val outputs = new Array[BDM[Double]](layerModels.length)
+ outputs(0) = layerModels(0).eval(data)
+ for (i <- 1 until layerModels.length) {
+ outputs(i) = layerModels(i).eval(outputs(i-1))
+ }
+ outputs
+ }
+
+ override def computeGradient(
+ data: BDM[Double],
+ target: BDM[Double],
+ cumGradient: Vector,
+ realBatchSize: Int): Double = {
+ val outputs = forward(data)
+ val deltas = new Array[BDM[Double]](layerModels.length)
+ val L = layerModels.length - 1
+ val (newE, newError) = layerModels.last match {
+ case flm: FunctionalLayerModel => flm.error(outputs.last, target)
+ case _ =>
+ throw new UnsupportedOperationException("Non-functional layer not supported at the top")
+ }
+ deltas(L) = new BDM[Double](0, 0)
+ deltas(L - 1) = newE
+ for (i <- (L - 2) to (0, -1)) {
+ deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1))
+ }
+ val grads = new Array[Array[Double]](layerModels.length)
+ for (i <- 0 until layerModels.length) {
+ val input = if (i==0) data else outputs(i - 1)
+ grads(i) = layerModels(i).grad(deltas(i), input)
+ }
+ // update cumGradient
+ val cumGradientArray = cumGradient.toArray
+ var offset = 0
+ // TODO: extract roll
+ for (i <- 0 until grads.length) {
+ val gradArray = grads(i)
+ var k = 0
+ while (k < gradArray.length) {
+ cumGradientArray(offset + k) += gradArray(k)
+ k += 1
+ }
+ offset += gradArray.length
+ }
+ newError
+ }
+
+ // TODO: do we really need to copy the weights? they should be read-only
+ override def weights(): Vector = {
+ // TODO: extract roll
+ var size = 0
+ for (i <- 0 until layerModels.length) {
+ size += layerModels(i).size
+ }
+ val array = new Array[Double](size)
+ var offset = 0
+ for (i <- 0 until layerModels.length) {
+ val layerWeights = layerModels(i).weights().toArray
+ System.arraycopy(layerWeights, 0, array, offset, layerWeights.length)
+ offset += layerWeights.length
+ }
+ Vectors.dense(array)
+ }
+
+ override def predict(data: Vector): Vector = {
+ val size = data.size
+ val result = forward(new BDM[Double](size, 1, data.toArray))
+ Vectors.dense(result.last.toArray)
+ }
+}
+
+/**
+ * Factory for feed forward ANN models
+ */
+private[ann] object FeedForwardModel {
+
+ /**
+ * Creates a model from a topology and weights
+ * @param topology topology
+ * @param weights weights
+ * @return model
+ */
+ def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = {
+ val layers = topology.layers
+ val layerModels = new Array[LayerModel](layers.length)
+ var offset = 0
+ for (i <- 0 until layers.length) {
+ layerModels(i) = layers(i).getInstance(weights, offset)
+ offset += layerModels(i).size
+ }
+ new FeedForwardModel(layerModels, topology)
+ }
+
+ /**
+ * Creates a model given a topology and seed
+ * @param topology topology
+ * @param seed seed for generating the weights
+ * @return model
+ */
+ def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = {
+ val layers = topology.layers
+ val layerModels = new Array[LayerModel](layers.length)
+ var offset = 0
+ for(i <- 0 until layers.length){
+ layerModels(i) = layers(i).getInstance(seed)
+ offset += layerModels(i).size
+ }
+ new FeedForwardModel(layerModels, topology)
+ }
+}
+
+/**
+ * Neural network gradient. Does nothing but calling Model's gradient
+ * @param topology topology
+ * @param dataStacker data stacker
+ */
+private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient {
+
+ override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
+ val gradient = Vectors.zeros(weights.size)
+ val loss = compute(data, label, weights, gradient)
+ (gradient, loss)
+ }
+
+ override def compute(
+ data: Vector,
+ label: Double,
+ weights: Vector,
+ cumGradient: Vector): Double = {
+ val (input, target, realBatchSize) = dataStacker.unstack(data)
+ val model = topology.getInstance(weights)
+ model.computeGradient(input, target, cumGradient, realBatchSize)
+ }
+}
+
+/**
+ * Stacks pairs of training samples (input, output) in one vector allowing them to pass
+ * through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks
+ * or matrices of inputs and outputs and then stack them in one vector.
+ * This can be used for further batch computations after unstacking.
+ * @param stackSize stack size
+ * @param inputSize size of the input vectors
+ * @param outputSize size of the output vectors
+ */
+private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int)
+ extends Serializable {
+
+ /**
+ * Stacks the data
+ * @param data RDD of vector pairs
+ * @return RDD of double (always zero) and vector that contains the stacked vectors
+ */
+ def stack(data: RDD[(Vector, Vector)]): RDD[(Double, Vector)] = {
+ val stackedData = if (stackSize == 1) {
+ data.map { v =>
+ (0.0,
+ Vectors.fromBreeze(BDV.vertcat(
+ v._1.toBreeze.toDenseVector,
+ v._2.toBreeze.toDenseVector))
+ ) }
+ } else {
+ data.mapPartitions { it =>
+ it.grouped(stackSize).map { seq =>
+ val size = seq.size
+ val bigVector = new Array[Double](inputSize * size + outputSize * size)
+ var i = 0
+ seq.foreach { case (in, out) =>
+ System.arraycopy(in.toArray, 0, bigVector, i * inputSize, inputSize)
+ System.arraycopy(out.toArray, 0, bigVector,
+ inputSize * size + i * outputSize, outputSize)
+ i += 1
+ }
+ (0.0, Vectors.dense(bigVector))
+ }
+ }
+ }
+ stackedData
+ }
+
+ /**
+ * Unstack the stacked vectors into matrices for batch operations
+ * @param data stacked vector
+ * @return pair of matrices holding input and output data and the real stack size
+ */
+ def unstack(data: Vector): (BDM[Double], BDM[Double], Int) = {
+ val arrData = data.toArray
+ val realStackSize = arrData.length / (inputSize + outputSize)
+ val input = new BDM(inputSize, realStackSize, arrData)
+ val target = new BDM(outputSize, realStackSize, arrData, inputSize * realStackSize)
+ (input, target, realStackSize)
+ }
+}
+
+/**
+ * Simple updater
+ */
+private[ann] class ANNUpdater extends Updater {
+
+ override def compute(
+ weightsOld: Vector,
+ gradient: Vector,
+ stepSize: Double,
+ iter: Int,
+ regParam: Double): (Vector, Double) = {
+ val thisIterStepSize = stepSize
+ val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+ Baxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
+ (Vectors.fromBreeze(brzWeights), 0)
+ }
+}
+
+/**
+ * MLlib-style trainer class that trains a network given the data and topology
+ * @param topology topology of ANN
+ * @param inputSize input size
+ * @param outputSize output size
+ */
+private[ml] class FeedForwardTrainer(
+ topology: Topology,
+ val inputSize: Int,
+ val outputSize: Int) extends Serializable {
+
+ // TODO: what if we need to pass random seed?
+ private var _weights = topology.getInstance(11L).weights()
+ private var _stackSize = 128
+ private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize)
+ private var _gradient: Gradient = new ANNGradient(topology, dataStacker)
+ private var _updater: Updater = new ANNUpdater()
+ private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100)
+
+ /**
+ * Returns weights
+ * @return weights
+ */
+ def getWeights: Vector = _weights
+
+ /**
+ * Sets weights
+ * @param value weights
+ * @return trainer
+ */
+ def setWeights(value: Vector): FeedForwardTrainer = {
+ _weights = value
+ this
+ }
+
+ /**
+ * Sets the stack size
+ * @param value stack size
+ * @return trainer
+ */
+ def setStackSize(value: Int): FeedForwardTrainer = {
+ _stackSize = value
+ dataStacker = new DataStacker(value, inputSize, outputSize)
+ this
+ }
+
+ /**
+ * Sets the SGD optimizer
+ * @return SGD optimizer
+ */
+ def SGDOptimizer: GradientDescent = {
+ val sgd = new GradientDescent(_gradient, _updater)
+ optimizer = sgd
+ sgd
+ }
+
+ /**
+ * Sets the LBFGS optimizer
+ * @return LBGS optimizer
+ */
+ def LBFGSOptimizer: LBFGS = {
+ val lbfgs = new LBFGS(_gradient, _updater)
+ optimizer = lbfgs
+ lbfgs
+ }
+
+ /**
+ * Sets the updater
+ * @param value updater
+ * @return trainer
+ */
+ def setUpdater(value: Updater): FeedForwardTrainer = {
+ _updater = value
+ updateUpdater(value)
+ this
+ }
+
+ /**
+ * Sets the gradient
+ * @param value gradient
+ * @return trainer
+ */
+ def setGradient(value: Gradient): FeedForwardTrainer = {
+ _gradient = value
+ updateGradient(value)
+ this
+ }
+
+ private[this] def updateGradient(gradient: Gradient): Unit = {
+ optimizer match {
+ case lbfgs: LBFGS => lbfgs.setGradient(gradient)
+ case sgd: GradientDescent => sgd.setGradient(gradient)
+ case other => throw new UnsupportedOperationException(
+ s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.")
+ }
+ }
+
+ private[this] def updateUpdater(updater: Updater): Unit = {
+ optimizer match {
+ case lbfgs: LBFGS => lbfgs.setUpdater(updater)
+ case sgd: GradientDescent => sgd.setUpdater(updater)
+ case other => throw new UnsupportedOperationException(
+ s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.")
+ }
+ }
+
+ /**
+ * Trains the ANN
+ * @param data RDD of input and output vector pairs
+ * @return model
+ */
+ def train(data: RDD[(Vector, Vector)]): TopologyModel = {
+ val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights)
+ topology.getInstance(newWeights)
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 581d8fa7749be..45df557a89908 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -18,14 +18,13 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 36fe1bd40469c..6f70b96b17ec6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -18,12 +18,11 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
@@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame
*/
@Experimental
final class DecisionTreeClassifier(override val uid: String)
- extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
+ extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeParams with TreeClassifierParams {
def this() = this(Identifiable.randomUID("dtc"))
@@ -106,8 +105,9 @@ object DecisionTreeClassifier {
@Experimental
final class DecisionTreeClassificationModel private[ml] (
override val uid: String,
- override val rootNode: Node)
- extends PredictionModel[Vector, DecisionTreeClassificationModel]
+ override val rootNode: Node,
+ override val numClasses: Int)
+ extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {
require(rootNode != null,
@@ -117,14 +117,31 @@ final class DecisionTreeClassificationModel private[ml] (
* Construct a decision tree classification model.
* @param rootNode Root node of tree, with other nodes attached.
*/
- def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
+ private[ml] def this(rootNode: Node, numClasses: Int) =
+ this(Identifiable.randomUID("dtc"), rootNode, numClasses)
override protected def predict(features: Vector): Double = {
- rootNode.predict(features)
+ rootNode.predictImpl(features).prediction
+ }
+
+ override protected def predictRaw(features: Vector): Vector = {
+ Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
}
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
- copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
+ copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
+ .setParent(parent)
}
override def toString: String = {
@@ -149,6 +166,6 @@ private[ml] object DecisionTreeClassificationModel {
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
- new DecisionTreeClassificationModel(uid, rootNode)
+ new DecisionTreeClassificationModel(uid, rootNode, -1)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index eb0b1a0a405fc..3073a2a61ce83 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -190,13 +190,13 @@ final class GBTClassificationModel(
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
- val treePredictions = _trees.map(_.rootNode.predict(features))
+ val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
if (prediction > 0.0) 1.0 else 0.0
}
override def copy(extra: ParamMap): GBTClassificationModel = {
- copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra)
+ copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent)
}
override def toString: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8fc9199fb4602..21fbe38ca8233 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -30,10 +30,11 @@ import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel
/**
@@ -41,12 +42,115 @@ import org.apache.spark.storage.StorageLevel
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
- with HasThreshold with HasStandardization
+ with HasStandardization with HasThreshold {
+
+ /**
+ * Set threshold in binary classification, in range [0, 1].
+ *
+ * If the estimated probability of class label 1 is > threshold, then predict 1, else 0.
+ * A high threshold encourages the model to predict 0 more often;
+ * a low threshold encourages the model to predict 1 more often.
+ *
+ * Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`.
+ * When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared.
+ * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
+ * equivalent.
+ *
+ * Default is 0.5.
+ * @group setParam
+ */
+ def setThreshold(value: Double): this.type = {
+ if (isSet(thresholds)) clear(thresholds)
+ set(threshold, value)
+ }
+
+ /**
+ * Get threshold for binary classification.
+ *
+ * If [[threshold]] is set, returns that value.
+ * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification),
+ * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
+ * Otherwise, returns [[threshold]] default value.
+ *
+ * @group getParam
+ * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2.
+ */
+ override def getThreshold: Double = {
+ checkThresholdConsistency()
+ if (isSet(thresholds)) {
+ val ts = $(thresholds)
+ require(ts.length == 2, "Logistic Regression getThreshold only applies to" +
+ " binary classification, but thresholds has length != 2. thresholds: " + ts.mkString(","))
+ 1.0 / (1.0 + ts(0) / ts(1))
+ } else {
+ $(threshold)
+ }
+ }
+
+ /**
+ * Set thresholds in multiclass (or binary) classification to adjust the probability of
+ * predicting each class. Array must have length equal to the number of classes, with values >= 0.
+ * The class with largest value p/t is predicted, where p is the original probability of that
+ * class and t is the class' threshold.
+ *
+ * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared.
+ * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
+ * equivalent.
+ *
+ * @group setParam
+ */
+ def setThresholds(value: Array[Double]): this.type = {
+ if (isSet(threshold)) clear(threshold)
+ set(thresholds, value)
+ }
+
+ /**
+ * Get thresholds for binary or multiclass classification.
+ *
+ * If [[thresholds]] is set, return its value.
+ * Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary
+ * classification: (1-threshold, threshold).
+ * If neither are set, throw an exception.
+ *
+ * @group getParam
+ */
+ override def getThresholds: Array[Double] = {
+ checkThresholdConsistency()
+ if (!isSet(thresholds) && isSet(threshold)) {
+ val t = $(threshold)
+ Array(1-t, t)
+ } else {
+ $(thresholds)
+ }
+ }
+
+ /**
+ * If [[threshold]] and [[thresholds]] are both set, ensures they are consistent.
+ * @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent
+ */
+ protected def checkThresholdConsistency(): Unit = {
+ if (isSet(threshold) && isSet(thresholds)) {
+ val ts = $(thresholds)
+ require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" +
+ s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" +
+ s" classification, but Param thresholds is set with length ${ts.length}." +
+ " Clear one Param value to fix this problem.")
+ val t = 1.0 / (1.0 + ts(0) / ts(1))
+ require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" +
+ s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)")
+ }
+ }
+
+ override def validateParams(): Unit = {
+ checkThresholdConsistency()
+ }
+}
/**
* :: Experimental ::
* Logistic regression.
- * Currently, this class only supports binary classification.
+ * Currently, this class only supports binary classification. It will support multiclass
+ * in the future.
*/
@Experimental
class LogisticRegression(override val uid: String)
@@ -94,25 +198,29 @@ class LogisticRegression(override val uid: String)
* Whether to fit an intercept term.
* Default is true.
* @group setParam
- * */
+ */
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
/**
* Whether to standardize the training features before fitting the model.
* The coefficients of models will be always returned on the original scale,
- * so it will be transparent for users. Note that when no regularization,
- * with or without standardization, the models should be always converged to
- * the same solution.
+ * so it will be transparent for users. Note that with/without standardization,
+ * the models should be always converged to the same solution when no regularization
+ * is applied. In R's GLMNET package, the default behavior is true as well.
* Default is true.
* @group setParam
- * */
+ */
def setStandardization(value: Boolean): this.type = set(standardization, value)
setDefault(standardization -> true)
- /** @group setParam */
- def setThreshold(value: Double): this.type = set(threshold, value)
- setDefault(threshold -> 0.5)
+ override def setThreshold(value: Double): this.type = super.setThreshold(value)
+
+ override def getThreshold: Double = super.getThreshold
+
+ override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+
+ override def getThresholds: Array[Double] = super.getThresholds
override protected def train(dataset: DataFrame): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
@@ -252,7 +360,13 @@ class LogisticRegression(override val uid: String)
if (handlePersistence) instances.unpersist()
- copyValues(new LogisticRegressionModel(uid, weights, intercept))
+ val model = copyValues(new LogisticRegressionModel(uid, weights, intercept))
+ val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
+ model.transform(dataset),
+ $(probabilityCol),
+ $(labelCol),
+ objectiveHistory)
+ model.setSummary(logRegSummary)
}
override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
@@ -270,8 +384,13 @@ class LogisticRegressionModel private[ml] (
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
with LogisticRegressionParams {
- /** @group setParam */
- def setThreshold(value: Double): this.type = set(threshold, value)
+ override def setThreshold(value: Double): this.type = super.setThreshold(value)
+
+ override def getThreshold: Double = super.getThreshold
+
+ override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
+
+ override def getThresholds: Array[Double] = super.getThresholds
/** Margin (rawPrediction) for class label 1. For binary classification only. */
private val margin: Vector => Double = (features) => {
@@ -286,11 +405,44 @@ class LogisticRegressionModel private[ml] (
override val numClasses: Int = 2
+ private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
+
+ /**
+ * Gets summary of model on training set. An exception is
+ * thrown if `trainingSummary == None`.
+ */
+ def summary: LogisticRegressionTrainingSummary = trainingSummary match {
+ case Some(summ) => summ
+ case None =>
+ throw new SparkException(
+ "No training summary available for this LogisticRegressionModel",
+ new NullPointerException())
+ }
+
+ private[classification] def setSummary(
+ summary: LogisticRegressionTrainingSummary): this.type = {
+ this.trainingSummary = Some(summary)
+ this
+ }
+
+ /** Indicates whether a training summary exists for this model instance. */
+ def hasSummary: Boolean = trainingSummary.isDefined
+
+ /**
+ * Evaluates the model on a testset.
+ * @param dataset Test dataset to evaluate model on.
+ */
+ // TODO: decide on a good name before exposing to public API
+ private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
+ new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol))
+ }
+
/**
* Predict label for the given feature vector.
- * The behavior of this can be adjusted using [[threshold]].
+ * The behavior of this can be adjusted using [[thresholds]].
*/
override protected def predict(features: Vector): Double = {
+ // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
if (score(features) > getThreshold) 1 else 0
}
@@ -316,10 +468,11 @@ class LogisticRegressionModel private[ml] (
}
override def copy(extra: ParamMap): LogisticRegressionModel = {
- copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
+ copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent)
}
override protected def raw2prediction(rawPrediction: Vector): Double = {
+ // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
val t = getThreshold
val rawThreshold = if (t == 0.0) {
Double.NegativeInfinity
@@ -332,6 +485,7 @@ class LogisticRegressionModel private[ml] (
}
override protected def probability2prediction(probability: Vector): Double = {
+ // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
if (probability(1) > getThreshold) 1 else 0
}
}
@@ -407,6 +561,128 @@ private[classification] class MultiClassSummarizer extends Serializable {
}
}
+/**
+ * Abstraction for multinomial Logistic Regression Training results.
+ */
+sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
+
+ /** objective function (scaled loss + regularization) at each iteration. */
+ def objectiveHistory: Array[Double]
+
+ /** Number of training iterations until termination */
+ def totalIterations: Int = objectiveHistory.length
+
+}
+
+/**
+ * Abstraction for Logistic Regression Results for a given model.
+ */
+sealed trait LogisticRegressionSummary extends Serializable {
+
+ /** Dataframe outputted by the model's `transform` method. */
+ def predictions: DataFrame
+
+ /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */
+ def probabilityCol: String
+
+ /** Field in "predictions" which gives the true label of each sample. */
+ def labelCol: String
+
+}
+
+/**
+ * :: Experimental ::
+ * Logistic regression training results.
+ * @param predictions dataframe outputted by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the calibrated probability of
+ * each sample as a vector.
+ * @param labelCol field in "predictions" which gives the true label of each sample.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+@Experimental
+class BinaryLogisticRegressionTrainingSummary private[classification] (
+ predictions: DataFrame,
+ probabilityCol: String,
+ labelCol: String,
+ val objectiveHistory: Array[Double])
+ extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
+ with LogisticRegressionTrainingSummary {
+
+}
+
+/**
+ * :: Experimental ::
+ * Binary Logistic regression results for a given model.
+ * @param predictions dataframe outputted by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the calibrated probability of
+ * each sample.
+ * @param labelCol field in "predictions" which gives the true label of each sample.
+ */
+@Experimental
+class BinaryLogisticRegressionSummary private[classification] (
+ @transient override val predictions: DataFrame,
+ override val probabilityCol: String,
+ override val labelCol: String) extends LogisticRegressionSummary {
+
+ private val sqlContext = predictions.sqlContext
+ import sqlContext.implicits._
+
+ /**
+ * Returns a BinaryClassificationMetrics object.
+ */
+ // TODO: Allow the user to vary the number of bins using a setBins method in
+ // BinaryClassificationMetrics. For now the default is set to 100.
+ @transient private val binaryMetrics = new BinaryClassificationMetrics(
+ predictions.select(probabilityCol, labelCol).map {
+ case Row(score: Vector, label: Double) => (score(1), label)
+ }, 100
+ )
+
+ /**
+ * Returns the receiver operating characteristic (ROC) curve,
+ * which is a DataFrame having two fields (FPR, TPR)
+ * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
+ * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
+ */
+ @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")
+
+ /**
+ * Computes the area under the receiver operating characteristic (ROC) curve.
+ */
+ lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()
+
+ /**
+ * Returns the precision-recall curve, which is a DataFrame containing
+ * two fields (recall, precision), with (0.0, 1.0) prepended to it.
+ */
+ @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision")
+
+ /**
+ * Returns a DataFrame with two fields (threshold, F-Measure) with beta = 1.0.
+ */
+ @transient lazy val fMeasureByThreshold: DataFrame = {
+ binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
+ }
+
+ /**
+ * Returns a dataframe with two fields (threshold, precision) curve.
+ * Every possible probability obtained in transforming the dataset are used
+ * as thresholds used in calculating the precision.
+ */
+ @transient lazy val precisionByThreshold: DataFrame = {
+ binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
+ }
+
+ /**
+ * Returns a dataframe with two fields (threshold, recall) curve.
+ * Every possible probability obtained in transforming the dataset are used
+ * as thresholds used in calculating the recall.
+ */
+ @transient lazy val recallByThreshold: DataFrame = {
+ binaryMetrics.recallByThreshold().toDF("threshold", "recall")
+ }
+}
+
/**
* LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
* in binary classification for samples in sparse or dense vector in a online fashion.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
new file mode 100644
index 0000000000000..c154561886585
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed}
+import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
+import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology}
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql.DataFrame
+
+/** Params for Multilayer Perceptron. */
+private[ml] trait MultilayerPerceptronParams extends PredictorParams
+ with HasSeed with HasMaxIter with HasTol {
+ /**
+ * Layer sizes including input size and output size.
+ * @group param
+ */
+ final val layers: IntArrayParam = new IntArrayParam(this, "layers",
+ "Sizes of layers from input layer to output layer" +
+ " E.g., Array(780, 100, 10) means 780 inputs, " +
+ "one hidden layer with 100 neurons and output layer of 10 neurons.",
+ // TODO: how to check ALSO that all elements are greater than 0?
+ ParamValidators.arrayLengthGt(1)
+ )
+
+ /** @group setParam */
+ def setLayers(value: Array[Int]): this.type = set(layers, value)
+
+ /** @group getParam */
+ final def getLayers: Array[Int] = $(layers)
+
+ /**
+ * Block size for stacking input data in matrices to speed up the computation.
+ * Data is stacked within partitions. If block size is more than remaining data in
+ * a partition then it is adjusted to the size of this data.
+ * Recommended size is between 10 and 1000.
+ * @group expertParam
+ */
+ final val blockSize: IntParam = new IntParam(this, "blockSize",
+ "Block size for stacking input data in matrices. Data is stacked within partitions." +
+ " If block size is more than remaining data in a partition then " +
+ "it is adjusted to the size of this data. Recommended size is between 10 and 1000",
+ ParamValidators.gt(0))
+
+ /** @group setParam */
+ def setBlockSize(value: Int): this.type = set(blockSize, value)
+
+ /** @group getParam */
+ final def getBlockSize: Int = $(blockSize)
+
+ /**
+ * Set the maximum number of iterations.
+ * Default is 100.
+ * @group setParam
+ */
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /**
+ * Set the convergence tolerance of iterations.
+ * Smaller value will lead to higher accuracy with the cost of more iterations.
+ * Default is 1E-4.
+ * @group setParam
+ */
+ def setTol(value: Double): this.type = set(tol, value)
+
+ /**
+ * Set the seed for weights initialization.
+ * @group setParam
+ */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128)
+}
+
+/** Label to vector converter. */
+private object LabelConverter {
+ // TODO: Use OneHotEncoder instead
+ /**
+ * Encodes a label as a vector.
+ * Returns a vector of given length with zeroes at all positions
+ * and value 1.0 at the position that corresponds to the label.
+ *
+ * @param labeledPoint labeled point
+ * @param labelCount total number of labels
+ * @return pair of features and vector encoding of a label
+ */
+ def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = {
+ val output = Array.fill(labelCount)(0.0)
+ output(labeledPoint.label.toInt) = 1.0
+ (labeledPoint.features, Vectors.dense(output))
+ }
+
+ /**
+ * Converts a vector to a label.
+ * Returns the position of the maximal element of a vector.
+ *
+ * @param output label encoded with a vector
+ * @return label
+ */
+ def decodeLabel(output: Vector): Double = {
+ output.argmax.toDouble
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Classifier trainer based on the Multilayer Perceptron.
+ * Each layer has a sigmoid activation function; the output layer has softmax.
+ * Number of inputs has to be equal to the size of feature vectors.
+ * Number of outputs has to be equal to the total number of labels.
+ *
+ */
+@Experimental
+class MultilayerPerceptronClassifier(override val uid: String)
+ extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
+ with MultilayerPerceptronParams {
+
+ def this() = this(Identifiable.randomUID("mlpc"))
+
+ override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra)
+
+ /**
+ * Train a model using the given dataset and parameters.
+ * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
+ * and copying parameters into the model.
+ *
+ * @param dataset Training dataset
+ * @return Fitted model
+ */
+ override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = {
+ val myLayers = $(layers)
+ val labels = myLayers.last
+ val lpData = extractLabeledPoints(dataset)
+ val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels))
+ val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true)
+ val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last)
+ FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter))
+ FeedForwardTrainer.setStackSize($(blockSize))
+ val mlpModel = FeedForwardTrainer.train(data)
+ new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights())
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Classification model based on the Multilayer Perceptron.
+ * Each layer has sigmoid activation function, output layer has softmax.
+ * @param uid uid
+ * @param layers array of layer sizes including input and output layers
+ * @param weights vector of initial weights for the model that consists of the weights of layers
+ * @return prediction model
+ */
+@Experimental
+class MultilayerPerceptronClassificationModel private[ml] (
+ override val uid: String,
+ layers: Array[Int],
+ weights: Vector)
+ extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
+ with Serializable {
+
+ private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
+
+ /**
+ * Predict label for the given features.
+ * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+ */
+ override protected def predict(features: Vector): Double = {
+ LabelConverter.decodeLabel(mlpModel.predict(features))
+ }
+
+ override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
+ copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 1f547e4a98af7..97cbaf1fa8761 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -38,11 +38,11 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
* (default = 1.0).
* @group param
*/
- final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.",
+ final val smoothing: DoubleParam = new DoubleParam(this, "smoothing", "The smoothing parameter.",
ParamValidators.gtEq(0))
/** @group getParam */
- final def getLambda: Double = $(lambda)
+ final def getSmoothing: Double = $(smoothing)
/**
* The model type which is a string (case-sensitive).
@@ -69,7 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
* The input feature values must be nonnegative.
*/
class NaiveBayes(override val uid: String)
- extends Predictor[Vector, NaiveBayes, NaiveBayesModel]
+ extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams {
def this() = this(Identifiable.randomUID("nb"))
@@ -79,20 +79,21 @@ class NaiveBayes(override val uid: String)
* Default is 1.0.
* @group setParam
*/
- def setLambda(value: Double): this.type = set(lambda, value)
- setDefault(lambda -> 1.0)
+ def setSmoothing(value: Double): this.type = set(smoothing, value)
+ setDefault(smoothing -> 1.0)
/**
* Set the model type using a string (case-sensitive).
* Supported options: "multinomial" and "bernoulli".
* Default is "multinomial"
+ * @group setParam
*/
def setModelType(value: String): this.type = set(modelType, value)
setDefault(modelType -> OldNaiveBayes.Multinomial)
override protected def train(dataset: DataFrame): NaiveBayesModel = {
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
- val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType))
+ val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
NaiveBayesModel.fromOld(oldModel, this)
}
@@ -101,12 +102,15 @@ class NaiveBayes(override val uid: String)
/**
* Model produced by [[NaiveBayes]]
+ * @param pi log of class priors, whose dimension is C (number of classes)
+ * @param theta log of class conditional probabilities, whose dimension is C (number of classes)
+ * by D (number of features)
*/
class NaiveBayesModel private[ml] (
override val uid: String,
val pi: Vector,
val theta: Matrix)
- extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+ extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
import OldNaiveBayes.{Bernoulli, Multinomial}
@@ -129,29 +133,62 @@ class NaiveBayesModel private[ml] (
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
- override protected def predict(features: Vector): Double = {
+ override val numClasses: Int = pi.size
+
+ private def multinomialCalculation(features: Vector) = {
+ val prob = theta.multiply(features)
+ BLAS.axpy(1.0, pi, prob)
+ prob
+ }
+
+ private def bernoulliCalculation(features: Vector) = {
+ features.foreachActive((_, value) =>
+ if (value != 0.0 && value != 1.0) {
+ throw new SparkException(
+ s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.")
+ }
+ )
+ val prob = thetaMinusNegTheta.get.multiply(features)
+ BLAS.axpy(1.0, pi, prob)
+ BLAS.axpy(1.0, negThetaSum.get, prob)
+ prob
+ }
+
+ override protected def predictRaw(features: Vector): Vector = {
$(modelType) match {
case Multinomial =>
- val prob = theta.multiply(features)
- BLAS.axpy(1.0, pi, prob)
- prob.argmax
+ multinomialCalculation(features)
case Bernoulli =>
- features.foreachActive{ (index, value) =>
- if (value != 0.0 && value != 1.0) {
- throw new SparkException(
- s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features")
- }
- }
- val prob = thetaMinusNegTheta.get.multiply(features)
- BLAS.axpy(1.0, pi, prob)
- BLAS.axpy(1.0, negThetaSum.get, prob)
- prob.argmax
+ bernoulliCalculation(features)
case _ =>
// This should never happen.
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
}
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ var i = 0
+ val size = dv.size
+ val maxLog = dv.values.max
+ while (i < size) {
+ dv.values(i) = math.exp(dv.values(i) - maxLog)
+ i += 1
+ }
+ val probSum = dv.values.sum
+ i = 0
+ while (i < size) {
+ dv.values(i) = dv.values(i) / probSum
+ i += 1
+ }
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in NaiveBayesModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
+ }
+
override def copy(extra: ParamMap): NaiveBayesModel = {
copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 1741f19dc911c..1132d8046df67 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -138,7 +138,7 @@ final class OneVsRestModel private[ml] (
override def copy(extra: ParamMap): OneVsRestModel = {
val copied = new OneVsRestModel(
uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index dad451108626d..fdd1851ae5508 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -20,17 +20,16 @@ package org.apache.spark.ml.classification
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* (private[classification]) Params for probabilistic classification.
*/
private[classification] trait ProbabilisticClassifierParams
- extends ClassifierParams with HasProbabilityCol {
-
+ extends ClassifierParams with HasProbabilityCol with HasThresholds {
override protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean,
@@ -51,7 +50,7 @@ private[classification] trait ProbabilisticClassifierParams
* @tparam M Concrete Model type
*/
@DeveloperApi
-private[spark] abstract class ProbabilisticClassifier[
+abstract class ProbabilisticClassifier[
FeaturesType,
E <: ProbabilisticClassifier[FeaturesType, E, M],
M <: ProbabilisticClassificationModel[FeaturesType, M]]
@@ -59,6 +58,9 @@ private[spark] abstract class ProbabilisticClassifier[
/** @group setParam */
def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
+
+ /** @group setParam */
+ def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E]
}
@@ -72,7 +74,7 @@ private[spark] abstract class ProbabilisticClassifier[
* @tparam M Concrete Model type
*/
@DeveloperApi
-private[spark] abstract class ProbabilisticClassificationModel[
+abstract class ProbabilisticClassificationModel[
FeaturesType,
M <: ProbabilisticClassificationModel[FeaturesType, M]]
extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams {
@@ -80,6 +82,9 @@ private[spark] abstract class ProbabilisticClassificationModel[
/** @group setParam */
def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
+ /** @group setParam */
+ def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
+
/**
* Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
* parameters:
@@ -92,6 +97,11 @@ private[spark] abstract class ProbabilisticClassificationModel[
*/
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".transform() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
@@ -155,6 +165,14 @@ private[spark] abstract class ProbabilisticClassificationModel[
raw2probabilityInPlace(probs)
}
+ override protected def raw2prediction(rawPrediction: Vector): Double = {
+ if (!isDefined(thresholds)) {
+ rawPrediction.argmax
+ } else {
+ probability2prediction(raw2probability(rawPrediction))
+ }
+ }
+
/**
* Predict the probability of each class given the features.
* These predictions are also called class conditional probabilities.
@@ -170,8 +188,44 @@ private[spark] abstract class ProbabilisticClassificationModel[
/**
* Given a vector of class conditional probabilities, select the predicted label.
- * This may be overridden to support thresholds which favor particular labels.
+ * This supports thresholds which favor particular labels.
* @return predicted label
*/
- protected def probability2prediction(probability: Vector): Double = probability.argmax
+ protected def probability2prediction(probability: Vector): Double = {
+ if (!isDefined(thresholds)) {
+ probability.argmax
+ } else {
+ val thresholds: Array[Double] = getThresholds
+ val scaledProbability: Array[Double] =
+ probability.toArray.zip(thresholds).map { case (p, t) =>
+ if (t == 0.0) Double.PositiveInfinity else p / t
+ }
+ Vectors.dense(scaledProbability).argmax
+ }
+ }
+}
+
+private[ml] object ProbabilisticClassificationModel {
+
+ /**
+ * Normalize a vector of raw predictions to be a multinomial probability vector, in place.
+ *
+ * The input raw predictions should be >= 0.
+ * The output vector sums to 1, unless the input vector is all-0 (in which case the output is
+ * all-0 too).
+ *
+ * NOTE: This is NOT applicable to all models, only ones which effectively use class
+ * instance counts for raw predictions.
+ */
+ def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = {
+ val sum = v.values.sum
+ if (sum != 0) {
+ var i = 0
+ val size = v.size
+ while (i < size) {
+ v.values(i) /= sum
+ i += 1
+ }
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index bc19bd6df894f..11a6d72468333 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,22 +17,19 @@
package org.apache.spark.ml.classification
-import scala.collection.mutable
-
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+
/**
* :: Experimental ::
@@ -43,7 +40,7 @@ import org.apache.spark.sql.types.DoubleType
*/
@Experimental
final class RandomForestClassifier(override val uid: String)
- extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
+ extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
def this() = this(Identifiable.randomUID("rfc"))
@@ -98,7 +95,8 @@ final class RandomForestClassifier(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeClassificationModel])
- new RandomForestClassificationModel(trees, numClasses)
+ val numFeatures = oldDataset.first().features.size
+ new RandomForestClassificationModel(trees, numFeatures, numClasses)
}
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -121,13 +119,15 @@ object RandomForestClassifier {
* features.
* @param _trees Decision trees in the ensemble.
* Warning: These have null parents.
+ * @param numFeatures Number of features used by this model
*/
@Experimental
final class RandomForestClassificationModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel],
+ val numFeatures: Int,
override val numClasses: Int)
- extends ClassificationModel[Vector, RandomForestClassificationModel]
+ extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with TreeEnsembleModel with Serializable {
require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -136,8 +136,11 @@ final class RandomForestClassificationModel private[ml] (
* Construct a random forest classification model, with all trees weighted equally.
* @param trees Component trees
*/
- def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
- this(Identifiable.randomUID("rfc"), trees, numClasses)
+ private[ml] def this(
+ trees: Array[DecisionTreeClassificationModel],
+ numFeatures: Int,
+ numClasses: Int) =
+ this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -157,23 +160,59 @@ final class RandomForestClassificationModel private[ml] (
override protected def predictRaw(features: Vector): Vector = {
// TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
// Classifies using majority votes.
- // Ignore the weights since all are 1.0 for now.
- val votes = new Array[Double](numClasses)
+ // Ignore the tree weights since all are 1.0 for now.
+ val votes = Array.fill[Double](numClasses)(0.0)
_trees.view.foreach { tree =>
- val prediction = tree.rootNode.predict(features).toInt
- votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
+ val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
+ val total = classCounts.sum
+ if (total != 0) {
+ var i = 0
+ while (i < numClasses) {
+ votes(i) += classCounts(i) / total
+ i += 1
+ }
+ }
}
Vectors.dense(votes)
}
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv)
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
+ }
+
override def copy(extra: ParamMap): RandomForestClassificationModel = {
- copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
+ copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
+ .setParent(parent)
}
override def toString: String = {
s"RandomForestClassificationModel with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree based on total number of training instances used
+ * to build tree.
+ * - Normalize feature importance vector to sum to 1.
+ */
+ lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
@@ -195,6 +234,6 @@ private[ml] object RandomForestClassificationModel {
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
- new RandomForestClassificationModel(uid, newTrees, numClasses)
+ new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index dc192add6ca13..47a18cdb31b53 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -18,8 +18,8 @@
package org.apache.spark.ml.clustering
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap}
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed}
+import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap}
+import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
@@ -27,14 +27,13 @@ import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.util.Utils
/**
* Common params for KMeans and KMeansModel
*/
-private[clustering] trait KMeansParams
- extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
+private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
+ with HasSeed with HasPredictionCol with HasTol {
/**
* Set the number of clusters to create (k). Must be > 1. Default: 2.
@@ -45,31 +44,6 @@ private[clustering] trait KMeansParams
/** @group getParam */
def getK: Int = $(k)
- /**
- * Param the number of runs of the algorithm to execute in parallel. We initialize the algorithm
- * this many times with random starting conditions (configured by the initialization mode), then
- * return the best clustering found over any run. Must be >= 1. Default: 1.
- * @group param
- */
- final val runs = new IntParam(this, "runs",
- "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1)
-
- /** @group getParam */
- def getRuns: Int = $(runs)
-
- /**
- * Param the distance threshold within which we've consider centers to have converged.
- * If all centers move less than this Euclidean distance, we stop iterating one run.
- * Must be >= 0.0. Default: 1e-4
- * @group param
- */
- final val epsilon = new DoubleParam(this, "epsilon",
- "distance threshold within which we've consider centers to have converge",
- (value: Double) => value >= 0.0)
-
- /** @group getParam */
- def getEpsilon: Double = $(epsilon)
-
/**
* Param for the initialization algorithm. This can be either "random" to choose random points as
* initial cluster centers, or "k-means||" to use a parallel variant of k-means++
@@ -136,9 +110,9 @@ class KMeansModel private[ml] (
/**
* :: Experimental ::
- * K-means clustering with support for multiple parallel runs and a k-means++ like initialization
- * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
- * they are executed together with joint passes over the data for efficiency.
+ * K-means clustering with support for k-means|| initialization proposed by Bahmani et al.
+ *
+ * @see [[http://dx.doi.org/10.14778/2180912.2180915 Bahmani et al., Scalable k-means++.]]
*/
@Experimental
class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams {
@@ -146,10 +120,9 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean
setDefault(
k -> 2,
maxIter -> 20,
- runs -> 1,
initMode -> MLlibKMeans.K_MEANS_PARALLEL,
initSteps -> 5,
- epsilon -> 1e-4)
+ tol -> 1e-4)
override def copy(extra: ParamMap): KMeans = defaultCopy(extra)
@@ -174,10 +147,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean
def setMaxIter(value: Int): this.type = set(maxIter, value)
/** @group setParam */
- def setRuns(value: Int): this.type = set(runs, value)
-
- /** @group setParam */
- def setEpsilon(value: Double): this.type = set(epsilon, value)
+ def setTol(value: Double): this.type = set(tol, value)
/** @group setParam */
def setSeed(value: Long): this.type = set(seed, value)
@@ -191,8 +161,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean
.setInitializationSteps($(initSteps))
.setMaxIterations($(maxIter))
.setSeed($(seed))
- .setEpsilon($(epsilon))
- .setRuns($(runs))
+ .setEpsilon($(tol))
val parentModel = algo.run(rdd)
val model = new KMeansModel(uid, parentModel)
copyValues(model)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 4a82b77f0edcb..5d5cb7e94f45b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types.DoubleType
/**
* :: Experimental ::
- * Evaluator for binary classification, which expects two input columns: score and label.
+ * Evaluator for binary classification, which expects two input columns: rawPrediction and label.
*/
@Experimental
class BinaryClassificationEvaluator(override val uid: String)
@@ -50,6 +50,13 @@ class BinaryClassificationEvaluator(override val uid: String)
def setMetricName(value: String): this.type = set(metricName, value)
/** @group setParam */
+ def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
+
+ /**
+ * @group setParam
+ * @deprecated use [[setRawPredictionCol()]] instead
+ */
+ @deprecated("use setRawPredictionCol instead", "1.5.0")
def setScoreCol(value: String): this.type = set(rawPredictionCol, value)
/** @group setParam */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
new file mode 100644
index 0000000000000..44f779c1908d7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.types.DoubleType
+
+/**
+ * :: Experimental ::
+ * Evaluator for multiclass classification, which expects two input columns: prediction and label.
+ */
+@Experimental
+class MulticlassClassificationEvaluator (override val uid: String)
+ extends Evaluator with HasPredictionCol with HasLabelCol {
+
+ def this() = this(Identifiable.randomUID("mcEval"))
+
+ /**
+ * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`,
+ * `"weightedPrecision"`, `"weightedRecall"`)
+ * @group param
+ */
+ val metricName: Param[String] = {
+ val allowedParams = ParamValidators.inArray(Array("f1", "precision",
+ "recall", "weightedPrecision", "weightedRecall"))
+ new Param(this, "metricName", "metric name in evaluation " +
+ "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams)
+ }
+
+ /** @group getParam */
+ def getMetricName: String = $(metricName)
+
+ /** @group setParam */
+ def setMetricName(value: String): this.type = set(metricName, value)
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ setDefault(metricName -> "f1")
+
+ override def evaluate(dataset: DataFrame): Double = {
+ val schema = dataset.schema
+ SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
+ SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+
+ val predictionAndLabels = dataset.select($(predictionCol), $(labelCol))
+ .map { case Row(prediction: Double, label: Double) =>
+ (prediction, label)
+ }
+ val metrics = new MulticlassMetrics(predictionAndLabels)
+ val metric = $(metricName) match {
+ case "f1" => metrics.weightedFMeasure
+ case "precision" => metrics.precision
+ case "recall" => metrics.recall
+ case "weightedPrecision" => metrics.weightedPrecision
+ case "weightedRecall" => metrics.weightedRecall
+ }
+ metric
+ }
+
+ override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 67e4785bc3553..cfca494dcf468 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -90,7 +90,9 @@ final class Bucketizer(override val uid: String)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
- override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra)
+ override def copy(extra: ParamMap): Bucketizer = {
+ defaultCopy[Bucketizer](extra).setParent(parent)
+ }
}
private[feature] object Bucketizer {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index ecde80810580c..938447447a0a2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -114,6 +114,6 @@ class IDFModel private[ml] (
override def copy(extra: ParamMap): IDFModel = {
val copied = new IDFModel(uid, idfModel)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index b30adf3df48d2..1b494ec8b1727 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -41,6 +41,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
val min: DoubleParam = new DoubleParam(this, "min",
"lower bound of the output feature range")
+ /** @group getParam */
+ def getMin: Double = $(min)
+
/**
* upper bound after transformation, shared by all features
* Default: 1.0
@@ -49,6 +52,9 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
val max: DoubleParam = new DoubleParam(this, "max",
"upper bound of the output feature range")
+ /** @group getParam */
+ def getMax: Double = $(max)
+
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
val inputType = schema($(inputCol)).dataType
@@ -115,6 +121,9 @@ class MinMaxScaler(override val uid: String)
* :: Experimental ::
* Model fitted by [[MinMaxScaler]].
*
+ * @param originalMin min value for each original column during fitting
+ * @param originalMax max value for each original column during fitting
+ *
* TODO: The transformer does not yet set the metadata in the output column (SPARK-8529).
*/
@Experimental
@@ -136,7 +145,6 @@ class MinMaxScalerModel private[ml] (
/** @group setParam */
def setMax(value: Double): this.type = set(max, value)
-
override def transform(dataset: DataFrame): DataFrame = {
val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray
val minArray = originalMin.toArray
@@ -165,6 +173,6 @@ class MinMaxScalerModel private[ml] (
override def copy(extra: ParamMap): MinMaxScalerModel = {
val copied = new MinMaxScalerModel(uid, originalMin, originalMax)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 3825942795645..9c60d4084ec46 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transformSchema(schema: StructType): StructType = {
- val is = "_is_"
val inputColName = $(inputCol)
val outputColName = $(outputCol)
@@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
val outputAttrNames: Option[Array[String]] = inputAttr match {
case nominal: NominalAttribute =>
if (nominal.values.isDefined) {
- nominal.values.map(_.map(v => inputColName + is + v))
+ nominal.values
} else if (nominal.numValues.isDefined) {
- nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i))
+ nominal.numValues.map(n => Array.tabulate(n)(_.toString))
} else {
None
}
case binary: BinaryAttribute =>
if (binary.values.isDefined) {
- binary.values.map(_.map(v => inputColName + is + v))
+ binary.values
} else {
- Some(Array.tabulate(2)(i => inputColName + is + i))
+ Some(Array.tabulate(2)(_.toString))
}
case _: NumericAttribute =>
throw new RuntimeException(
@@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
override def transform(dataset: DataFrame): DataFrame = {
// schema transformation
- val is = "_is_"
val inputColName = $(inputCol)
val outputColName = $(outputCol)
val shouldDropLast = $(dropLast)
@@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
math.max(m0, m1)
}
).toInt + 1
- val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i)
+ val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames
val outputAttrs: Array[Attribute] =
filtered.map(name => BinaryAttribute.defaultAttr.withName(name))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 2d3bb680cf309..539084704b653 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -125,6 +125,6 @@ class PCAModel private[ml] (
override def copy(extra: ParamMap): PCAModel = {
val copied = new PCAModel(uid, pcaModel)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 0b428d278d908..a752dacd72d95 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -17,28 +17,22 @@
package org.apache.spark.ml.feature
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage}
+import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
* Base trait for [[RFormula]] and [[RFormulaModel]].
*/
private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
- /** @group getParam */
- def setFeaturesCol(value: String): this.type = set(featuresCol, value)
-
- /** @group getParam */
- def setLabelCol(value: String): this.type = set(labelCol, value)
protected def hasLabelCol(schema: StructType): Boolean = {
schema.map(_.name).contains($(labelCol))
@@ -62,40 +56,50 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
*/
val formula: Param[String] = new Param(this, "formula", "R model formula")
- private var parsedFormula: Option[ParsedRFormula] = None
-
/**
* Sets the formula to use for this transformer. Must be called before use.
* @group setParam
* @param value an R formula in string form (e.g. "y ~ x + z")
*/
- def setFormula(value: String): this.type = {
- parsedFormula = Some(RFormulaParser.parse(value))
- set(formula, value)
- this
- }
+ def setFormula(value: String): this.type = set(formula, value)
/** @group getParam */
def getFormula: String = $(formula)
+ /** @group setParam */
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
/** Whether the formula specifies fitting an intercept. */
private[ml] def hasIntercept: Boolean = {
- require(parsedFormula.isDefined, "Must call setFormula() first.")
- parsedFormula.get.hasIntercept
+ require(isDefined(formula), "Formula must be defined first.")
+ RFormulaParser.parse($(formula)).hasIntercept
}
override def fit(dataset: DataFrame): RFormulaModel = {
- require(parsedFormula.isDefined, "Must call setFormula() first.")
- val resolvedFormula = parsedFormula.get.resolve(dataset.schema)
+ require(isDefined(formula), "Formula must be defined first.")
+ val parsedFormula = RFormulaParser.parse($(formula))
+ val resolvedFormula = parsedFormula.resolve(dataset.schema)
// StringType terms and terms representing interactions need to be encoded before assembly.
// TODO(ekl) add support for feature interactions
val encoderStages = ArrayBuffer[PipelineStage]()
val tempColumns = ArrayBuffer[String]()
+ val takenNames = mutable.Set(dataset.columns: _*)
val encodedTerms = resolvedFormula.terms.map { term =>
dataset.schema(term) match {
case column if column.dataType == StringType =>
val indexCol = term + "_idx_" + uid
- val encodedCol = term + "_onehot_" + uid
+ val encodedCol = {
+ var tmp = term
+ while (takenNames.contains(tmp)) {
+ tmp += "_"
+ }
+ tmp
+ }
+ takenNames.add(indexCol)
+ takenNames.add(encodedCol)
encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol)
encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
tempColumns += indexCol
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
new file mode 100644
index 0000000000000..95e4305638730
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.{ParamMap, Param}
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.{SQLContext, DataFrame, Row}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ * Implements the transforms which are defined by SQL statement.
 * Implements the transformations that are defined by a SQL statement.
+ * where '__THIS__' represents the underlying table of the input dataset.
+ */
+@Experimental
+class SQLTransformer (override val uid: String) extends Transformer {
+
+ def this() = this(Identifiable.randomUID("sql"))
+
+ /**
+ * SQL statement parameter. The statement is provided in string form.
+ * @group param
+ */
+ final val statement: Param[String] = new Param[String](this, "statement", "SQL statement")
+
+ /** @group setParam */
+ def setStatement(value: String): this.type = set(statement, value)
+
+ /** @group getParam */
+ def getStatement: String = $(statement)
+
+ private val tableIdentifier: String = "__THIS__"
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ val tableName = Identifiable.randomUID(uid)
+ dataset.registerTempTable(tableName)
+ val realStatement = $(statement).replace(tableIdentifier, tableName)
+ val outputDF = dataset.sqlContext.sql(realStatement)
+ outputDF
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ val sc = SparkContext.getOrCreate()
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val dummyRDD = sc.parallelize(Seq(Row.empty))
+ val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
+ dummyDF.registerTempTable(tableIdentifier)
+ val outputSchema = sqlContext.sql($(statement)).schema
+ outputSchema
+ }
+
+ override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 72b545e5db3e4..f6d0b0c0e9e75 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -136,6 +136,6 @@ class StandardScalerModel private[ml] (
override def copy(extra: ParamMap): StandardScalerModel = {
val copied = new StandardScalerModel(uid, scaler)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
new file mode 100644
index 0000000000000..5d77ea08db657
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
+
+/**
+ * stop words list
+ */
+private object StopWords {
+
+ /**
+ * Use the same default stopwords list as scikit-learn.
+ * The original list can be found from "Glasgow Information Retrieval Group"
+ * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]]
+ */
+ val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again",
+ "against", "all", "almost", "alone", "along", "already", "also", "although", "always",
+ "am", "among", "amongst", "amoungst", "amount", "an", "and", "another",
+ "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are",
+ "around", "as", "at", "back", "be", "became", "because", "become",
+ "becomes", "becoming", "been", "before", "beforehand", "behind", "being",
+ "below", "beside", "besides", "between", "beyond", "bill", "both",
+ "bottom", "but", "by", "call", "can", "cannot", "cant", "co", "con",
+ "could", "couldnt", "cry", "de", "describe", "detail", "do", "done",
+ "down", "due", "during", "each", "eg", "eight", "either", "eleven", "else",
+ "elsewhere", "empty", "enough", "etc", "even", "ever", "every", "everyone",
+ "everything", "everywhere", "except", "few", "fifteen", "fify", "fill",
+ "find", "fire", "first", "five", "for", "former", "formerly", "forty",
+ "found", "four", "from", "front", "full", "further", "get", "give", "go",
+ "had", "has", "hasnt", "have", "he", "hence", "her", "here", "hereafter",
+ "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his",
+ "how", "however", "hundred", "i", "ie", "if", "in", "inc", "indeed",
+ "interest", "into", "is", "it", "its", "itself", "keep", "last", "latter",
+ "latterly", "least", "less", "ltd", "made", "many", "may", "me",
+ "meanwhile", "might", "mill", "mine", "more", "moreover", "most", "mostly",
+ "move", "much", "must", "my", "myself", "name", "namely", "neither",
+ "never", "nevertheless", "next", "nine", "no", "nobody", "none", "noone",
+ "nor", "not", "nothing", "now", "nowhere", "of", "off", "often", "on",
+ "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our",
+ "ours", "ourselves", "out", "over", "own", "part", "per", "perhaps",
+ "please", "put", "rather", "re", "same", "see", "seem", "seemed",
+ "seeming", "seems", "serious", "several", "she", "should", "show", "side",
+ "since", "sincere", "six", "sixty", "so", "some", "somehow", "someone",
+ "something", "sometime", "sometimes", "somewhere", "still", "such",
+ "system", "take", "ten", "than", "that", "the", "their", "them",
+ "themselves", "then", "thence", "there", "thereafter", "thereby",
+ "therefore", "therein", "thereupon", "these", "they", "thick", "thin",
+ "third", "this", "those", "though", "three", "through", "throughout",
+ "thru", "thus", "to", "together", "too", "top", "toward", "towards",
+ "twelve", "twenty", "two", "un", "under", "until", "up", "upon", "us",
+ "very", "via", "was", "we", "well", "were", "what", "whatever", "when",
+ "whence", "whenever", "where", "whereafter", "whereas", "whereby",
+ "wherein", "whereupon", "wherever", "whether", "which", "while", "whither",
+ "who", "whoever", "whole", "whom", "whose", "why", "will", "with",
+ "within", "without", "would", "yet", "you", "your", "yours", "yourself", "yourselves")
+}
+
+/**
+ * :: Experimental ::
+ * A feature transformer that filters out stop words from input.
+ * Note: null values from input array are preserved unless adding null to stopWords explicitly.
+ * @see [[http://en.wikipedia.org/wiki/Stop_words]]
+ */
+@Experimental
+class StopWordsRemover(override val uid: String)
+ extends Transformer with HasInputCol with HasOutputCol {
+
+ def this() = this(Identifiable.randomUID("stopWords"))
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /**
+ * the stop words set to be filtered out
+ * @group param
+ */
+ val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words")
+
+ /** @group setParam */
+ def setStopWords(value: Array[String]): this.type = set(stopWords, value)
+
+ /** @group getParam */
+ def getStopWords: Array[String] = $(stopWords)
+
+ /**
+ * whether to do a case sensitive comparison over the stop words
+ * @group param
+ */
+ val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive",
+ "whether to do case-sensitive comparison during filtering")
+
+ /** @group setParam */
+ def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value)
+
+ /** @group getParam */
+ def getCaseSensitive: Boolean = $(caseSensitive)
+
+ setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false)
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema)
+ val t = if ($(caseSensitive)) {
+ val stopWordsSet = $(stopWords).toSet
+ udf { terms: Seq[String] =>
+ terms.filter(s => !stopWordsSet.contains(s))
+ }
+ } else {
+ val toLower = (s: String) => if (s != null) s.toLowerCase else s
+ val lowerStopWords = $(stopWords).map(toLower(_)).toSet
+ udf { terms: Seq[String] =>
+ terms.filter(s => !lowerStopWords.contains(toLower(s)))
+ }
+ }
+
+ val metadata = outputSchema($(outputCol)).metadata
+ dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ val inputType = schema($(inputCol)).dataType
+ require(inputType.sameType(ArrayType(StringType)),
+ s"Input type must be ArrayType(StringType) but got $inputType.")
+ val outputFields = schema.fields :+
+ StructField($(outputCol), inputType, schema($(inputCol)).nullable)
+ StructType(outputFields)
+ }
+
+ override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index bf7be363b8224..9f6e7b6b6b274 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -20,19 +20,21 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{NumericType, StringType, StructType}
+import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap
/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
-private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
+private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
+ with HasHandleInvalid {
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -57,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
* If the input column is numeric, we cast it to string and index the string values.
* The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0.
+ *
+ * @see [[IndexToString]] for the inverse transformation
*/
@Experimental
class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
@@ -64,13 +68,16 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
def this() = this(Identifiable.randomUID("strIdx"))
+ /** @group setParam */
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+ setDefault(handleInvalid, "error")
+
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- // TODO: handle unseen labels
override def fit(dataset: DataFrame): StringIndexerModel = {
val counts = dataset.select(col($(inputCol)).cast(StringType))
@@ -110,6 +117,10 @@ class StringIndexerModel private[ml] (
map
}
+ /** @group setParam */
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+ setDefault(handleInvalid, "error")
+
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -127,14 +138,24 @@ class StringIndexerModel private[ml] (
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else {
- // TODO: handle unseen labels
throw new SparkException(s"Unseen label: $label.")
}
}
+
val outputColName = $(outputCol)
val metadata = NominalAttribute.defaultAttr
.withName(outputColName).withValues(labels).toMetadata()
- dataset.select(col("*"),
+ // If we are skipping invalid records, filter them out.
+ val filteredDataset = (getHandleInvalid) match {
+ case "skip" => {
+ val filterer = udf { label: String =>
+ labelToIndex.contains(label)
+ }
+ dataset.where(filterer(dataset($(inputCol))))
+ }
+ case _ => dataset
+ }
+ filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
}
@@ -149,6 +170,97 @@ class StringIndexerModel private[ml] (
override def copy(extra: ParamMap): StringIndexerModel = {
val copied = new StringIndexerModel(uid, labels)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * A [[Transformer]] that maps a column of string indices back to a new column of corresponding
+ * string values using either the ML attributes of the input column, or if provided using the labels
+ * supplied by the user.
+ * All original columns are kept during transformation.
+ *
+ * @see [[StringIndexer]] for converting strings into indices
+ */
+@Experimental
+class IndexToString private[ml] (
+ override val uid: String) extends Transformer
+ with HasInputCol with HasOutputCol {
+
+ def this() =
+ this(Identifiable.randomUID("idxToStr"))
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /**
+ * Optional labels to be provided by the user, if not supplied column
+ * metadata is read for labels. The default value is an empty array,
+ * but the empty array is ignored and column metadata used instead.
+ * @group setParam
+ */
+ def setLabels(value: Array[String]): this.type = set(labels, value)
+
+ /**
+ * Param for array of labels.
+ * Optional labels to be provided by the user, if not supplied column
+ * metadata is read for labels.
+ * @group param
+ */
+ final val labels: StringArrayParam = new StringArrayParam(this, "labels",
+ "array of labels, if not provided metadata from inputCol is used instead.")
+ setDefault(labels, Array.empty[String])
+
+ /**
+ * Optional labels to be provided by the user, if not supplied column
+ * metadata is read for labels.
+ * @group getParam
+ */
+ final def getLabels: Array[String] = $(labels)
+
+ /** Transform the schema for the inverse transformation */
+ override def transformSchema(schema: StructType): StructType = {
+ val inputColName = $(inputCol)
+ val inputDataType = schema(inputColName).dataType
+ require(inputDataType.isInstanceOf[NumericType],
+ s"The input column $inputColName must be a numeric type, " +
+ s"but got $inputDataType.")
+ val inputFields = schema.fields
+ val outputColName = $(outputCol)
+ require(inputFields.forall(_.name != outputColName),
+ s"Output column $outputColName already exists.")
+ val attr = NominalAttribute.defaultAttr.withName($(outputCol))
+ val outputFields = inputFields :+ attr.toStructField()
+ StructType(outputFields)
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ val inputColSchema = dataset.schema($(inputCol))
+ // If the labels array is empty use column metadata
+ val values = if ($(labels).isEmpty) {
+ Attribute.fromStructField(inputColSchema)
+ .asInstanceOf[NominalAttribute].values.get
+ } else {
+ $(labels)
+ }
+ val indexer = udf { index: Double =>
+ val idx = index.toInt
+ if (0 <= idx && idx < values.length) {
+ values(idx)
+ } else {
+ throw new SparkException(s"Unseen index: $index ??")
+ }
+ }
+ val outputColName = $(outputCol)
+ dataset.select(col("*"),
+ indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
+ }
+
+ override def copy(extra: ParamMap): IndexToString = {
+ defaultCopy(extra)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index c73bdccdef5fa..6875aefe065bb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -405,6 +405,6 @@ class VectorIndexerModel private[ml] (
override def copy(extra: ParamMap): VectorIndexerModel = {
val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
new file mode 100644
index 0000000000000..772bebeff214b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils}
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ * This class takes a feature vector and outputs a new feature vector with a subarray of the
+ * original features.
+ *
+ * The subset of features can be specified with either indices ([[setIndices()]])
+ * or names ([[setNames()]]). At least one feature must be selected. Duplicate features
+ * are not allowed, so there can be no overlap between selected indices and names.
+ *
+ * The output vector will order features with the selected indices first (in the order given),
+ * followed by the selected names (in the order given).
+ */
+@Experimental
+final class VectorSlicer(override val uid: String)
+ extends Transformer with HasInputCol with HasOutputCol {
+
+ def this() = this(Identifiable.randomUID("vectorSlicer"))
+
+ /**
+ * An array of indices to select features from a vector column.
+ * There can be no overlap with [[names]].
+ * @group param
+ */
+ val indices = new IntArrayParam(this, "indices",
+ "An array of indices to select features from a vector column." +
+ " There can be no overlap with names.", VectorSlicer.validIndices)
+
+ setDefault(indices -> Array.empty[Int])
+
+ /** @group getParam */
+ def getIndices: Array[Int] = $(indices)
+
+ /** @group setParam */
+ def setIndices(value: Array[Int]): this.type = set(indices, value)
+
+ /**
+ * An array of feature names to select features from a vector column.
+ * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s.
+ * There can be no overlap with [[indices]].
+ * @group param
+ */
+ val names = new StringArrayParam(this, "names",
+ "An array of feature names to select features from a vector column." +
+ " There can be no overlap with indices.", VectorSlicer.validNames)
+
+ setDefault(names -> Array.empty[String])
+
+ /** @group getParam */
+ def getNames: Array[String] = $(names)
+
+ /** @group setParam */
+ def setNames(value: Array[String]): this.type = set(names, value)
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def validateParams(): Unit = {
+ require($(indices).length > 0 || $(names).length > 0,
+ s"VectorSlicer requires that at least one feature be selected.")
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ // Validity checks
+ transformSchema(dataset.schema)
+ val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
+ inputAttr.numAttributes.foreach { numFeatures =>
+ val maxIndex = $(indices).max
+ require(maxIndex < numFeatures,
+ s"Selected feature index $maxIndex invalid for only $numFeatures input features.")
+ }
+
+ // Prepare output attributes
+ val inds = getSelectedFeatureIndices(dataset.schema)
+ val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs =>
+ inds.map(index => attrs(index))
+ }
+ val outputAttr = selectedAttrs match {
+ case Some(attrs) => new AttributeGroup($(outputCol), attrs)
+ case None => new AttributeGroup($(outputCol), inds.length)
+ }
+
+ // Select features
+ val slicer = udf { vec: Vector =>
+ vec match {
+ case features: DenseVector => Vectors.dense(inds.map(features.apply))
+ case features: SparseVector => features.slice(inds)
+ }
+ }
+ dataset.withColumn($(outputCol),
+ slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
+ }
+
+ /** Get the feature indices in order: indices, names */
+ private def getSelectedFeatureIndices(schema: StructType): Array[Int] = {
+ val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names))
+ val indFeatures = $(indices)
+ val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length
+ lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" +
+ s" sets of features, but they overlap." +
+ s" indices: ${indFeatures.mkString("[", ",", "]")}." +
+ s" names: " +
+ nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]")
+ require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg)
+ indFeatures ++ nameFeatures
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+
+ if (schema.fieldNames.contains($(outputCol))) {
+ throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
+ }
+ val numFeaturesSelected = $(indices).length + $(names).length
+ val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected)
+ val outputFields = schema.fields :+ outputAttr.toStructField()
+ StructType(outputFields)
+ }
+
+ override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
+}
+
+private[feature] object VectorSlicer {
+
+ /** Return true if given feature indices are valid */
+ def validIndices(indices: Array[Int]): Boolean = {
+ if (indices.isEmpty) {
+ true
+ } else {
+ indices.length == indices.distinct.length && indices.forall(_ >= 0)
+ }
+ }
+
+ /** Return true if given feature names are valid */
+ def validNames(names: Array[String]): Boolean = {
+ names.forall(_.nonEmpty) && names.length == names.distinct.length
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 6ea6590956300..5af775a4159ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -18,15 +18,17 @@
package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
+import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types._
/**
@@ -146,6 +148,40 @@ class Word2VecModel private[ml] (
wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {
+
+ /**
+ * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
 * the vector being the DenseVector that it is mapped to.
+ */
+ @transient lazy val getVectors: DataFrame = {
+ val sc = SparkContext.getOrCreate()
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+ val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble)))
+ sc.parallelize(wordVec.toSeq).toDF("word", "vector")
+ }
+
+ /**
+ * Find "num" number of words closest in similarity to the given word.
+ * Returns a dataframe with the words and the cosine similarities between the
+ * synonyms and the given word.
+ */
+ def findSynonyms(word: String, num: Int): DataFrame = {
+ findSynonyms(wordVectors.transform(word), num)
+ }
+
+ /**
 * Find "num" number of words closest in similarity to the given vector representation
+ * of the word. Returns a dataframe with the words and the cosine similarities between the
+ * synonyms and the given word vector.
+ */
+ def findSynonyms(word: Vector, num: Int): DataFrame = {
+ val sc = SparkContext.getOrCreate()
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+ sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
+ }
+
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -185,6 +221,6 @@ class Word2VecModel private[ml] (
override def copy(extra: ParamMap): Word2VecModel = {
val copied = new Word2VecModel(uid, wordVectors)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 954aa17e26a02..91c0a5631319d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -166,6 +166,11 @@ object ParamValidators {
def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) =>
allowed.contains(value)
}
+
+ /** Check that the array length is greater than lowerBound. */
+ def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) =>
+ value.length > lowerBound
+ }
}
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
@@ -554,13 +559,26 @@ trait Params extends Identifiable with Serializable {
/**
* Copies param values from this instance to another instance for params shared by them.
- * @param to the target instance
- * @param extra extra params to be copied
+ *
+ * This handles default Params and explicitly set Params separately.
+ * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are
+ * copied from and to [[paramMap]].
+ * Warning: This implicitly assumes that this [[Params]] instance and the target instance
+ * share the same set of default Params.
+ *
+ * @param to the target instance, which should work with the same set of default Params as this
+ * source instance
+ * @param extra extra params to be copied to the target's [[paramMap]]
* @return the target instance with param values copied
*/
protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = {
- val map = extractParamMap(extra)
+ val map = paramMap ++ extra
params.foreach { param =>
+ // copy default Params
+ if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
+ to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
+ }
+ // copy explicitly set Params
if (map.contains(param) && to.hasParam(param.name)) {
to.set(param.name, map(param))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index f7ae1de522e01..8c16c6149b40d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -45,14 +45,24 @@ private[shared] object SharedParamsCodeGen {
" These probabilities should be treated as confidences, not precise probabilities.",
Some("\"probability\"")),
ParamDesc[Double]("threshold",
- "threshold in binary classification prediction, in range [0, 1]",
- isValid = "ParamValidators.inRange(0, 1)"),
+ "threshold in binary classification prediction, in range [0, 1]", Some("0.5"),
+ isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
+ ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
+ " to adjust the probability of predicting each class." +
+ " Array must have length equal to the number of classes, with values >= 0." +
+ " The class with largest value p/t is predicted, where p is the original probability" +
+ " of that class and t is the class' threshold.",
+ isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false),
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
isValid = "ParamValidators.gtEq(1)"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
+ ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
+ "will filter out rows with bad values), or error (which will throw an error). More " +
+ "options may be added later.",
+ isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
" before fitting the model.", Some("true")),
ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
@@ -60,7 +70,9 @@ private[shared] object SharedParamsCodeGen {
" For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
isValid = "ParamValidators.inRange(0, 1)"),
ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
- ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."))
+ ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."),
+ ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
+ "all instance weights as 1.0."))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
@@ -74,7 +86,8 @@ private[shared] object SharedParamsCodeGen {
name: String,
doc: String,
defaultValueStr: Option[String] = None,
- isValid: String = "") {
+ isValid: String = "",
+ finalMethods: Boolean = true) {
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
require(doc.nonEmpty) // TODO: more rigorous on doc
@@ -88,6 +101,7 @@ private[shared] object SharedParamsCodeGen {
case _ if c == classOf[Double] => "DoubleParam"
case _ if c == classOf[Boolean] => "BooleanParam"
case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
+ case _ if c.isArray && c.getComponentType == classOf[Double] => s"DoubleArrayParam"
case _ => s"Param[${getTypeString(c)}]"
}
}
@@ -131,6 +145,11 @@ private[shared] object SharedParamsCodeGen {
} else {
""
}
+ val methodStr = if (param.finalMethods) {
+ "final def"
+ } else {
+ "def"
+ }
s"""
|/**
@@ -145,7 +164,7 @@ private[shared] object SharedParamsCodeGen {
| final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|$setDefault
| /** @group getParam */
- | final def get$Name: $T = $$($name)
+ | $methodStr get$Name: $T = $$($name)
|}
|""".stripMargin
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 65e48e4ee5083..c26768953e3db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -139,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params {
}
/**
- * Trait for shared param threshold.
+ * Trait for shared param threshold (default: 0.5).
*/
private[ml] trait HasThreshold extends Params {
@@ -149,8 +149,25 @@ private[ml] trait HasThreshold extends Params {
*/
final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))
+ setDefault(threshold, 0.5)
+
+ /** @group getParam */
+ def getThreshold: Double = $(threshold)
+}
+
+/**
+ * Trait for shared param thresholds.
+ */
+private[ml] trait HasThresholds extends Params {
+
+ /**
+ * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold..
+ * @group param
+ */
+ final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0))
+
/** @group getParam */
- final def getThreshold: Double = $(threshold)
+ def getThresholds: Array[Double] = $(thresholds)
}
/**
@@ -232,6 +249,21 @@ private[ml] trait HasFitIntercept extends Params {
final def getFitIntercept: Boolean = $(fitIntercept)
}
+/**
+ * Trait for shared param handleInvalid.
+ */
+private[ml] trait HasHandleInvalid extends Params {
+
+ /**
+ * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later..
+ * @group param
+ */
+ final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.", ParamValidators.inArray(Array("skip", "error")))
+
+ /** @group getParam */
+ final def getHandleInvalid: String = $(handleInvalid)
+}
+
/**
* Trait for shared param standardization (default: true).
*/
@@ -310,4 +342,19 @@ private[ml] trait HasStepSize extends Params {
/** @group getParam */
final def getStepSize: Double = $(stepSize)
}
+
+/**
+ * Trait for shared param weightCol.
+ */
+private[ml] trait HasWeightCol extends Params {
+
+ /**
+ * Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0..
+ * @group param
+ */
+ final val weightCol: Param[String] = new Param[String](this, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.")
+
+ /** @group getParam */
+ final def getWeightCol: String = $(weightCol)
+}
// scalastyle:on
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 9f70592ccad7e..f5a022c31ed90 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -17,9 +17,10 @@
package org.apache.spark.ml.api.r
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.feature.RFormula
-import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.DataFrame
@@ -44,4 +45,26 @@ private[r] object SparkRWrappers {
val pipeline = new Pipeline().setStages(Array(formula, estimator))
pipeline.fit(df)
}
+
+ def getModelWeights(model: PipelineModel): Array[Double] = {
+ model.stages.last match {
+ case m: LinearRegressionModel =>
+ Array(m.intercept) ++ m.weights.toArray
+ case _: LogisticRegressionModel =>
+ throw new UnsupportedOperationException(
+ "No weights available for LogisticRegressionModel") // SPARK-9492
+ }
+ }
+
+ def getModelFeatures(model: PipelineModel): Array[String] = {
+ model.stages.last match {
+ case m: LinearRegressionModel =>
+ val attrs = AttributeGroup.fromStructField(
+ m.summary.predictions.schema(m.summary.featuresCol))
+ Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
+ case _: LogisticRegressionModel =>
+ throw new UnsupportedOperationException(
+ "No features names available for LogisticRegressionModel") // SPARK-9492
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 2e44cd4cc6a22..7db8ad8d27918 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -219,7 +219,7 @@ class ALSModel private[ml] (
override def copy(extra: ParamMap): ALSModel = {
val copied = new ALSModel(uid, rank, userFactors, itemFactors)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 6f3340c2f02be..a2bcd67401d08 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -107,14 +107,14 @@ final class DecisionTreeRegressionModel private[ml] (
* Construct a decision tree regression model.
* @param rootNode Root node of tree, with other nodes attached.
*/
- def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
+ private[ml] def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
override protected def predict(features: Vector): Double = {
- rootNode.predict(features)
+ rootNode.predictImpl(features).prediction
}
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
- copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra)
+ copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent)
}
override def toString: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index e38dc73ee0ba7..b66e61f37dd5e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -180,12 +180,12 @@ final class GBTRegressionModel(
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
- val treePredictions = _trees.map(_.rootNode.predict(features))
+ val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}
override def copy(extra: ParamMap): GBTRegressionModel = {
- copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra)
+ copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent)
}
override def toString: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
new file mode 100644
index 0000000000000..0f33bae30e622
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol}
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
+import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression, IsotonicRegressionModel => MLlibIsotonicRegressionModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, lit, udf}
+import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for isotonic regression.
+ */
+private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol
+ with HasLabelCol with HasPredictionCol with HasWeightCol with Logging {
+
+ /**
+ * Param for whether the output sequence should be isotonic/increasing (true) or
+ * antitonic/decreasing (false).
+ * @group param
+ */
+ final val isotonic: BooleanParam =
+ new BooleanParam(this, "isotonic",
+ "whether the output sequence should be isotonic/increasing (true) or " +
+ "antitonic/decreasing (false)")
+
+ /** @group getParam */
+ final def getIsotonic: Boolean = $(isotonic)
+
+ /**
+ * Param for the index of the feature if [[featuresCol]] is a vector column (default: `0`), no
+ * effect otherwise.
+ * @group param
+ */
+ final val featureIndex: IntParam = new IntParam(this, "featureIndex",
+ "The index of the feature if featuresCol is a vector column, no effect otherwise.")
+
+ /** @group getParam */
+ final def getFeatureIndex: Int = $(featureIndex)
+
+ setDefault(isotonic -> true, featureIndex -> 0)
+
+ /** Checks whether the input has weight column. */
+ protected[ml] def hasWeightCol: Boolean = {
+ isDefined(weightCol) && $(weightCol) != ""
+ }
+
+ /**
+ * Extracts (label, feature, weight) from input dataset.
+ */
+ protected[ml] def extractWeightedLabeledPoints(
+ dataset: DataFrame): RDD[(Double, Double, Double)] = {
+ val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
+ val idx = $(featureIndex)
+ val extract = udf { v: Vector => v(idx) }
+ extract(col($(featuresCol)))
+ } else {
+ col($(featuresCol))
+ }
+ val w = if (hasWeightCol) {
+ col($(weightCol))
+ } else {
+ lit(1.0)
+ }
+ dataset.select(col($(labelCol)), f, w)
+ .map { case Row(label: Double, feature: Double, weights: Double) =>
+ (label, feature, weights)
+ }
+ }
+
+ /**
+ * Validates and transforms input schema.
+ * @param schema input schema
+ * @param fitting whether this is in fitting or prediction
+ * @return output schema
+ */
+ protected[ml] def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean): StructType = {
+ if (fitting) {
+ SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ if (hasWeightCol) {
+ SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
+ } else {
+ logInfo("The weight column is not defined. Treat all instance weights as 1.0.")
+ }
+ }
+ val featuresType = schema($(featuresCol)).dataType
+ require(featuresType == DoubleType || featuresType.isInstanceOf[VectorUDT])
+ SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Isotonic regression.
+ *
+ * Currently implemented using parallelized pool adjacent violators algorithm.
+ * Only univariate (single feature) algorithm supported.
+ *
+ * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]].
+ */
+@Experimental
+class IsotonicRegression(override val uid: String) extends Estimator[IsotonicRegressionModel]
+ with IsotonicRegressionBase {
+
+ def this() = this(Identifiable.randomUID("isoReg"))
+
+ /** @group setParam */
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ /** @group setParam */
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ def setIsotonic(value: Boolean): this.type = set(isotonic, value)
+
+ /** @group setParam */
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
+ /** @group setParam */
+ def setFeatureIndex(value: Int): this.type = set(featureIndex, value)
+
+ override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
+
+ override def fit(dataset: DataFrame): IsotonicRegressionModel = {
+ validateAndTransformSchema(dataset.schema, fitting = true)
+ // Extract columns from data. If dataset is persisted, do not persist oldDataset.
+ val instances = extractWeightedLabeledPoints(dataset)
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic))
+ val oldModel = isotonicRegression.run(instances)
+
+ copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this))
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = true)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by IsotonicRegression.
+ * Predicts using a piecewise linear function.
+ *
+ * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]].
+ *
+ * @param oldModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]]
+ * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]].
+ */
+@Experimental
+class IsotonicRegressionModel private[ml] (
+ override val uid: String,
+ private val oldModel: MLlibIsotonicRegressionModel)
+ extends Model[IsotonicRegressionModel] with IsotonicRegressionBase {
+
+ /** @group setParam */
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ def setFeatureIndex(value: Int): this.type = set(featureIndex, value)
+
+ /** Boundaries in increasing order for which predictions are known. */
+ def boundaries: Vector = Vectors.dense(oldModel.boundaries)
+
+ /**
+ * Predictions associated with the boundaries at the same index, monotone because of isotonic
+ * regression.
+ */
+ def predictions: Vector = Vectors.dense(oldModel.predictions)
+
+ override def copy(extra: ParamMap): IsotonicRegressionModel = {
+ copyValues(new IsotonicRegressionModel(uid, oldModel), extra)
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ val predict = dataset.schema($(featuresCol)).dataType match {
+ case DoubleType =>
+ udf { feature: Double => oldModel.predict(feature) }
+ case _: VectorUDT =>
+ val idx = $(featureIndex)
+ udf { features: Vector => oldModel.predict(features(idx)) }
+ }
+ dataset.withColumn($(predictionCol), predict(col($(featuresCol))))
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = false)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 89718e0f3e15a..884003eb38524 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructField
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.StatCounter
@@ -44,7 +45,7 @@ import org.apache.spark.util.StatCounter
*/
private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
- with HasFitIntercept
+ with HasFitIntercept with HasStandardization
/**
* :: Experimental ::
@@ -83,6 +84,18 @@ class LinearRegression(override val uid: String)
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
+ /**
+ * Whether to standardize the training features before fitting the model.
+ * The coefficients of models will be always returned on the original scale,
+ * so it will be transparent for users. Note that with/without standardization,
+ * the models should be always converged to the same solution when no regularization
+ * is applied. In R's GLMNET package, the default behavior is true as well.
+ * Default is true.
+ * @group setParam
+ */
+ def setStandardization(value: Boolean): this.type = set(standardization, value)
+ setDefault(standardization -> true)
+
/**
* Set the ElasticNet mixing parameter.
* For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
@@ -146,9 +159,10 @@ class LinearRegression(override val uid: String)
val model = new LinearRegressionModel(uid, weights, intercept)
val trainingSummary = new LinearRegressionTrainingSummary(
- model.transform(dataset).select($(predictionCol), $(labelCol)),
+ model.transform(dataset),
$(predictionCol),
$(labelCol),
+ $(featuresCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
}
@@ -163,12 +177,24 @@ class LinearRegression(override val uid: String)
val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
- featuresStd, featuresMean, effectiveL2RegParam)
+ $(standardization), featuresStd, featuresMean, effectiveL2RegParam)
val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
} else {
- new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegParam, $(tol))
+ def effectiveL1RegFun = (index: Int) => {
+ if ($(standardization)) {
+ effectiveL1RegParam
+ } else {
+ // If `standardization` is false, we still standardize the data
+ // to improve the rate of convergence; as a result, we have to
+ // perform this reverse standardization by penalizing each component
+ // differently to get effectively the same objective function when
+ // the training dataset is not standardized.
+ if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0
+ }
+ }
+ new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
}
val initialWeights = Vectors.zeros(numFeatures)
@@ -221,9 +247,10 @@ class LinearRegression(override val uid: String)
val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
val trainingSummary = new LinearRegressionTrainingSummary(
- model.transform(dataset).select($(predictionCol), $(labelCol)),
+ model.transform(dataset),
$(predictionCol),
$(labelCol),
+ $(featuresCol),
objectiveHistory)
model.setSummary(trainingSummary)
}
@@ -285,7 +312,7 @@ class LinearRegressionModel private[ml] (
override def copy(extra: ParamMap): LinearRegressionModel = {
val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept))
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
- newModel
+ newModel.setParent(parent)
}
}
@@ -300,6 +327,7 @@ class LinearRegressionTrainingSummary private[regression] (
predictions: DataFrame,
predictionCol: String,
labelCol: String,
+ val featuresCol: String,
val objectiveHistory: Array[Double])
extends LinearRegressionSummary(predictions, predictionCol, labelCol) {
@@ -452,6 +480,7 @@ class LinearRegressionSummary private[regression] (
* @param weights The weights/coefficients corresponding to the features.
* @param labelStd The standard deviation value of the label.
* @param labelMean The mean value of the label.
+ * @param fitIntercept Whether to fit an intercept term.
* @param featuresStd The standard deviation values of the features.
* @param featuresMean The mean values of the features.
*/
@@ -564,6 +593,7 @@ private class LeastSquaresCostFun(
labelStd: Double,
labelMean: Double,
fitIntercept: Boolean,
+ standardization: Boolean,
featuresStd: Array[Double],
featuresMean: Array[Double],
effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
@@ -580,14 +610,38 @@ private class LeastSquaresCostFun(
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
})
- // regVal is the sum of weight squares for L2 regularization
- val norm = brzNorm(weights, 2.0)
- val regVal = 0.5 * effectiveL2regParam * norm * norm
+ val totalGradientArray = leastSquaresAggregator.gradient.toArray
- val loss = leastSquaresAggregator.loss + regVal
- val gradient = leastSquaresAggregator.gradient
- axpy(effectiveL2regParam, w, gradient)
+ val regVal = if (effectiveL2regParam == 0.0) {
+ 0.0
+ } else {
+ var sum = 0.0
+ w.foreachActive { (index, value) =>
+ // The following code will compute the loss of the regularization; also
+ // the gradient of the regularization, and add back to totalGradientArray.
+ sum += {
+ if (standardization) {
+ totalGradientArray(index) += effectiveL2regParam * value
+ value * value
+ } else {
+ if (featuresStd(index) != 0.0) {
+ // If `standardization` is false, we still standardize the data
+ // to improve the rate of convergence; as a result, we have to
+ // perform this reverse standardization by penalizing each component
+ // differently to get effectively the same objective function when
+ // the training dataset is not standardized.
+ val temp = value / (featuresStd(index) * featuresStd(index))
+ totalGradientArray(index) += effectiveL2regParam * temp
+ value * temp
+ } else {
+ 0.0
+ }
+ }
+ }
+ }
+ 0.5 * effectiveL2regParam * sum
+ }
- (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+ (leastSquaresAggregator.loss + regVal, new BDV(totalGradientArray))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 506a878c2553b..2f36da371f577 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+
/**
* :: Experimental ::
@@ -87,7 +87,8 @@ final class RandomForestRegressor(override val uid: String)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeRegressionModel])
- new RandomForestRegressionModel(trees)
+ val numFeatures = oldDataset.first().features.size
+ new RandomForestRegressionModel(trees, numFeatures)
}
override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra)
@@ -108,11 +109,13 @@ object RandomForestRegressor {
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
* @param _trees Decision trees in the ensemble.
+ * @param numFeatures Number of features used by this model
*/
@Experimental
final class RandomForestRegressionModel private[ml] (
override val uid: String,
- private val _trees: Array[DecisionTreeRegressionModel])
+ private val _trees: Array[DecisionTreeRegressionModel],
+ val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
with TreeEnsembleModel with Serializable {
@@ -122,7 +125,8 @@ final class RandomForestRegressionModel private[ml] (
* Construct a random forest regression model, with all trees weighted equally.
* @param trees Component trees
*/
- def this(trees: Array[DecisionTreeRegressionModel]) = this(Identifiable.randomUID("rfr"), trees)
+ private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
+ this(Identifiable.randomUID("rfr"), trees, numFeatures)
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
@@ -143,17 +147,34 @@ final class RandomForestRegressionModel private[ml] (
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
- _trees.map(_.rootNode.predict(features)).sum / numTrees
+ _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
}
override def copy(extra: ParamMap): RandomForestRegressionModel = {
- copyValues(new RandomForestRegressionModel(uid, _trees), extra)
+ copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
}
override def toString: String = {
s"RandomForestRegressionModel with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree based on total number of training instances used
+ * to build tree.
+ * - Normalize feature importance vector to sum to 1.
+ */
+ lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
@@ -173,6 +194,6 @@ private[ml] object RandomForestRegressionModel {
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent.uid, newTrees)
+ new RandomForestRegressionModel(parent.uid, newTrees, -1)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index bbc2427ca7d3d..cd24931293903 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -19,8 +19,9 @@ package org.apache.spark.ml.tree
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
- Node => OldNode, Predict => OldPredict}
+ Node => OldNode, Predict => OldPredict, ImpurityStats}
/**
* :: DeveloperApi ::
@@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable {
/** Impurity measure at this node (for training data) */
def impurity: Double
+ /**
+ * Statistics aggregated from training data at this node, used to compute prediction, impurity,
+ * and probabilities.
+ * For classification, the array of class counts must be normalized to a probability distribution.
+ */
+ private[ml] def impurityStats: ImpurityCalculator
+
/** Recursive prediction helper method */
- private[ml] def predict(features: Vector): Double = prediction
+ private[ml] def predictImpl(features: Vector): LeafNode
/**
* Get the number of nodes in tree below this node, including leaf nodes.
@@ -64,6 +72,12 @@ sealed abstract class Node extends Serializable {
* @param id Node ID using old format IDs
*/
private[ml] def toOld(id: Int): OldNode
+
+ /**
+ * Trace down the tree, and return the largest feature index used in any split.
+ * @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
+ */
+ private[ml] def maxSplitFeatureIndex(): Int
}
private[ml] object Node {
@@ -75,7 +89,8 @@ private[ml] object Node {
if (oldNode.isLeaf) {
// TODO: Once the implementation has been moved to this API, then include sufficient
// statistics here.
- new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
+ new LeafNode(prediction = oldNode.predict.predict,
+ impurity = oldNode.impurity, impurityStats = null)
} else {
val gain = if (oldNode.stats.nonEmpty) {
oldNode.stats.get.gain
@@ -85,7 +100,7 @@ private[ml] object Node {
new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
- split = Split.fromOld(oldNode.split.get, categoricalFeatures))
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
}
}
}
@@ -99,11 +114,13 @@ private[ml] object Node {
@DeveloperApi
final class LeafNode private[ml] (
override val prediction: Double,
- override val impurity: Double) extends Node {
+ override val impurity: Double,
+ override private[ml] val impurityStats: ImpurityCalculator) extends Node {
- override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
+ override def toString: String =
+ s"LeafNode(prediction = $prediction, impurity = $impurity)"
- override private[ml] def predict(features: Vector): Double = prediction
+ override private[ml] def predictImpl(features: Vector): LeafNode = this
override private[tree] def numDescendants: Int = 0
@@ -115,10 +132,11 @@ final class LeafNode private[ml] (
override private[tree] def subtreeDepth: Int = 0
override private[ml] def toOld(id: Int): OldNode = {
- // NOTE: We do NOT store 'prob' in the new API currently.
- new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
- None, None, None, None)
+ new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
+ impurity, isLeaf = true, None, None, None, None)
}
+
+ override private[ml] def maxSplitFeatureIndex(): Int = -1
}
/**
@@ -139,17 +157,18 @@ final class InternalNode private[ml] (
val gain: Double,
val leftChild: Node,
val rightChild: Node,
- val split: Split) extends Node {
+ val split: Split,
+ override private[ml] val impurityStats: ImpurityCalculator) extends Node {
override def toString: String = {
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
}
- override private[ml] def predict(features: Vector): Double = {
+ override private[ml] def predictImpl(features: Vector): LeafNode = {
if (split.shouldGoLeft(features)) {
- leftChild.predict(features)
+ leftChild.predictImpl(features)
} else {
- rightChild.predict(features)
+ rightChild.predictImpl(features)
}
}
@@ -172,14 +191,18 @@ final class InternalNode private[ml] (
override private[ml] def toOld(id: Int): OldNode = {
assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+ " since the old API does not support deep trees.")
- // NOTE: We do NOT store 'prob' in the new API currently.
- new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
- Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
+ new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity,
+ isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
Some(rightChild.toOld(OldNode.rightChildIndex(id))),
Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
new OldPredict(leftChild.prediction, prob = 0.0),
new OldPredict(rightChild.prediction, prob = 0.0))))
}
+
+ override private[ml] def maxSplitFeatureIndex(): Int = {
+ math.max(split.featureIndex,
+ math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex()))
+ }
}
private object InternalNode {
@@ -223,36 +246,36 @@ private object InternalNode {
*
* @param id We currently use the same indexing as the old implementation in
* [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
- * @param predictionStats Predicted label + class probability (for classification).
- * We will later modify this to store aggregate statistics for labels
- * to provide all class probabilities (for classification) and maybe a
- * distribution (for regression).
* @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
* so that we do not need to consider splitting it further.
- * @param stats Old structure for storing stats about information gain, prediction, etc.
- * This is legacy and will be modified in the future.
+ * @param stats Impurity statistics for this node.
*/
private[tree] class LearningNode(
var id: Int,
- var predictionStats: OldPredict,
- var impurity: Double,
var leftChild: Option[LearningNode],
var rightChild: Option[LearningNode],
var split: Option[Split],
var isLeaf: Boolean,
- var stats: Option[OldInformationGainStats]) extends Serializable {
+ var stats: ImpurityStats) extends Serializable {
/**
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
*/
def toNode: Node = {
if (leftChild.nonEmpty) {
- assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
+ assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
- new InternalNode(predictionStats.predict, impurity, stats.get.gain,
- leftChild.get.toNode, rightChild.get.toNode, split.get)
+ new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+ leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
} else {
- new LeafNode(predictionStats.predict, impurity)
+ if (stats.valid) {
+ new LeafNode(stats.impurityCalculator.predict, stats.impurity,
+ stats.impurityCalculator)
+ } else {
+ // Here we want to keep the same behavior as the old mllib.DecisionTreeModel
+ new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
+ }
+
}
}
@@ -263,16 +286,14 @@ private[tree] object LearningNode {
/** Create a node with some of its fields set. */
def apply(
id: Int,
- predictionStats: OldPredict,
- impurity: Double,
- isLeaf: Boolean): LearningNode = {
- new LearningNode(id, predictionStats, impurity, None, None, None, false, None)
+ isLeaf: Boolean,
+ stats: ImpurityStats): LearningNode = {
+ new LearningNode(id, None, None, None, false, stats)
}
/** Create an empty node with the given node index. Values must be set later on. */
def emptyNode(nodeIndex: Int): LearningNode = {
- new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
- None, None, None, false, None)
+ new LearningNode(nodeIndex, None, None, None, false, null)
}
// The below indexing methods were copied from spark.mllib.tree.model.Node
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 15b56bd844bad..4ac51a475474a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,14 +26,16 @@ import org.apache.spark.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
TimeTracker}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
+import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -180,13 +182,17 @@ private[ml] object RandomForest extends Logging {
parentUID match {
case Some(uid) =>
if (strategy.algo == OldAlgo.Classification) {
- topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode))
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
+ }
} else {
topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
}
case None =>
if (strategy.algo == OldAlgo.Classification) {
- topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode))
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
+ }
} else {
topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
}
@@ -549,9 +555,9 @@ private[ml] object RandomForest extends Logging {
}
// find best split for each node
- val (split: Split, stats: InformationGainStats, predict: Predict) =
+ val (split: Split, stats: ImpurityStats) =
binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
- (nodeIndex, (split, stats, predict))
+ (nodeIndex, (split, stats))
}.collectAsMap()
timer.stop("chooseSplits")
@@ -568,17 +574,15 @@ private[ml] object RandomForest extends Logging {
val nodeIndex = node.id
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val (split: Split, stats: InformationGainStats, predict: Predict) =
+ val (split: Split, stats: ImpurityStats) =
nodeToBestSplits(aggNodeIndex)
logDebug("best split = " + split)
// Extract info for this node. Create children if not leaf.
val isLeaf =
(stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
- node.predictionStats = predict
node.isLeaf = isLeaf
- node.stats = Some(stats)
- node.impurity = stats.impurity
+ node.stats = stats
logDebug("Node = " + node)
if (!isLeaf) {
@@ -587,9 +591,9 @@ private[ml] object RandomForest extends Logging {
val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
- stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+ leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
- stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+ rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
if (nodeIdCache.nonEmpty) {
val nodeIndexUpdater = NodeIndexUpdater(
@@ -621,28 +625,44 @@ private[ml] object RandomForest extends Logging {
}
/**
- * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+ * Calculate the impurity statistics for a given (feature, split) based upon left/right aggregates.
+ * @param stats the recycled impurity statistics for all of this feature's splits;
+ * only 'impurity' and 'impurityCalculator' are valid between each iteration
* @param leftImpurityCalculator left node aggregates for this (feature, split)
* @param rightImpurityCalculator right node aggregate for this (feature, split)
- * @return information gain and statistics for split
+ * @param metadata learning and dataset metadata for DecisionTree
+ * @return Impurity statistics for this (feature, split)
*/
- private def calculateGainForSplit(
+ private def calculateImpurityStats(
+ stats: ImpurityStats,
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator,
- metadata: DecisionTreeMetadata,
- impurity: Double): InformationGainStats = {
+ metadata: DecisionTreeMetadata): ImpurityStats = {
+
+ val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
+ leftImpurityCalculator.copy.add(rightImpurityCalculator)
+ } else {
+ stats.impurityCalculator
+ }
+
+ val impurity: Double = if (stats == null) {
+ parentImpurityCalculator.calculate()
+ } else {
+ stats.impurity
+ }
+
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
+ val totalCount = leftCount + rightCount
+
// If left child or right child doesn't satisfy minimum instances per node,
// then this split is invalid, return invalid information gain stats.
if ((leftCount < metadata.minInstancesPerNode) ||
(rightCount < metadata.minInstancesPerNode)) {
- return InformationGainStats.invalidInformationGainStats
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
- val totalCount = leftCount + rightCount
-
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
@@ -654,39 +674,11 @@ private[ml] object RandomForest extends Logging {
// if information gain doesn't satisfy minimum information gain,
// then this split is invalid, return invalid information gain stats.
if (gain < metadata.minInfoGain) {
- return InformationGainStats.invalidInformationGainStats
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
- // calculate left and right predict
- val leftPredict = calculatePredict(leftImpurityCalculator)
- val rightPredict = calculatePredict(rightImpurityCalculator)
-
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
- leftPredict, rightPredict)
- }
-
- private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
- val predict = impurityCalculator.predict
- val prob = impurityCalculator.prob(predict)
- new Predict(predict, prob)
- }
-
- /**
- * Calculate predict value for current node, given stats of any split.
- * Note that this function is called only once for each node.
- * @param leftImpurityCalculator left node aggregates for a split
- * @param rightImpurityCalculator right node aggregates for a split
- * @return predict value and impurity for current node
- */
- private def calculatePredictImpurity(
- leftImpurityCalculator: ImpurityCalculator,
- rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
- val parentNodeAgg = leftImpurityCalculator.copy
- parentNodeAgg.add(rightImpurityCalculator)
- val predict = calculatePredict(parentNodeAgg)
- val impurity = parentNodeAgg.calculate()
-
- (predict, impurity)
+ new ImpurityStats(gain, impurity, parentImpurityCalculator,
+ leftImpurityCalculator, rightImpurityCalculator)
}
/**
@@ -698,14 +690,14 @@ private[ml] object RandomForest extends Logging {
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
- node: LearningNode): (Split, InformationGainStats, Predict) = {
+ node: LearningNode): (Split, ImpurityStats) = {
- // Calculate prediction and impurity if current node is top node
+ // Calculate InformationGain and ImpurityStats if the current node is the top node
val level = LearningNode.indexToLevel(node.id)
- var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) {
- None
+ var gainAndImpurityStats: ImpurityStats = if (level == 0) {
+ null
} else {
- Some((node.predictionStats, node.impurity))
+ node.stats
}
// For each (feature, split), calculate the gain, and select the best (feature, split).
@@ -734,11 +726,9 @@ private[ml] object RandomForest extends Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
- predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
- (splitIdx, gainStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIdx, gainAndImpurityStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
@@ -750,11 +740,9 @@ private[ml] object RandomForest extends Logging {
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats =
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
- predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
- (splitIndex, gainStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
@@ -825,11 +813,9 @@ private[ml] object RandomForest extends Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
- predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
- (splitIndex, gainStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
@@ -839,7 +825,7 @@ private[ml] object RandomForest extends Logging {
}
}.maxBy(_._2.gain)
- (bestSplit, bestSplitStats, predictionAndImpurity.get._1)
+ (bestSplit, bestSplitStats)
}
/**
@@ -1129,4 +1115,94 @@ private[ml] object RandomForest extends Logging {
}
}
+ /**
+ * Given a Random Forest model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree based on total number of training instances used
+ * to build tree.
+ * - Normalize feature importance vector to sum to 1.
+ *
+ * Note: This should not be used with Gradient-Boosted Trees. It only makes sense for
+ * independently trained trees.
+ * @param trees Unweighted forest of trees
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
+ val totalImportances = new OpenHashMap[Int, Double]()
+ trees.foreach { tree =>
+ // Aggregate feature importance vector for this tree
+ val importances = new OpenHashMap[Int, Double]()
+ computeFeatureImportance(tree.rootNode, importances)
+ // Normalize importance vector for this tree, and add it to total.
+ // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
+ val treeNorm = importances.map(_._2).sum
+ if (treeNorm != 0) {
+ importances.foreach { case (idx, impt) =>
+ val normImpt = impt / treeNorm
+ totalImportances.changeValue(idx, normImpt, _ + normImpt)
+ }
+ }
+ }
+ // Normalize importances
+ normalizeMapValues(totalImportances)
+ // Construct vector
+ val d = if (numFeatures != -1) {
+ numFeatures
+ } else {
+ // Find max feature index used in trees
+ val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+ maxFeatureIndex + 1
+ }
+ if (d == 0) {
+ assert(totalImportances.size == 0, s"Unknown error in computing RandomForest feature" +
+ s" importance: No splits in forest, but some non-zero importances.")
+ }
+ val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+ Vectors.sparse(d, indices.toArray, values.toArray)
+ }
+
+ /**
+ * Recursive method for computing feature importances for one tree.
+ * This walks down the tree, adding to the importance of 1 feature at each node.
+ * @param node Current node in recursion
+ * @param importances Aggregate feature importances, modified by this method
+ */
+ private[impl] def computeFeatureImportance(
+ node: Node,
+ importances: OpenHashMap[Int, Double]): Unit = {
+ node match {
+ case n: InternalNode =>
+ val feature = n.split.featureIndex
+ val scaledGain = n.gain * n.impurityStats.count
+ importances.changeValue(feature, scaledGain, _ + scaledGain)
+ computeFeatureImportance(n.leftChild, importances)
+ computeFeatureImportance(n.rightChild, importances)
+ case n: LeafNode =>
+ // do nothing
+ }
+ }
+
+ /**
+ * Normalize the values of this map to sum to 1, in place.
+ * If all values are 0, this method does nothing.
+ * @param map Map with non-negative values.
+ */
+ private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+ val total = map.map(_._2).sum
+ if (total != 0) {
+ val keys = map.iterator.map(_._1).toArray
+ keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+ }
+ }
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 22873909c33fa..b77191156f68f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -53,6 +53,12 @@ private[ml] trait DecisionTreeModel {
val header = toString + "\n"
header + rootNode.subtreeToString(2)
}
+
+ /**
+ * Trace down the tree, and return the largest feature index used in any split.
+ * @return Max feature index used in a split, or -1 if there are no splits (single leaf node).
+ */
+ private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index a0c5238d966bf..dbd8d31571d2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -17,9 +17,10 @@
package org.apache.spark.ml.tree
+import org.apache.spark.ml.classification.ClassifierParams
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
+import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
@@ -162,7 +163,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
oldAlgo: OldAlgo.Algo,
oldImpurity: OldImpurity,
subsamplingRate: Double): OldStrategy = {
- val strategy = OldStrategy.defaultStategy(oldAlgo)
+ val strategy = OldStrategy.defaultStrategy(oldAlgo)
strategy.impurity = oldImpurity
strategy.checkpointInterval = getCheckpointInterval
strategy.maxBins = getMaxBins
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index f979319cc4b58..4792eb0f0a288 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -160,6 +160,6 @@ class CrossValidatorModel private[ml] (
uid,
bestModel.copy(extra).asInstanceOf[Model[_]],
avgMetrics.clone())
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
index ddd34a54503a6..bd213e7362e94 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala
@@ -19,11 +19,19 @@ package org.apache.spark.ml.util
import java.util.UUID
+import org.apache.spark.annotation.DeveloperApi
+
/**
+ * :: DeveloperApi ::
+ *
* Trait for an object with an immutable unique ID that identifies itself and its derivatives.
+ *
+ * WARNING: There have not yet been final discussions on this API, so it may be broken in future
+ * releases.
*/
-private[spark] trait Identifiable {
+@DeveloperApi
+trait Identifiable {
/**
* An immutable unique ID for the object and its derivatives.
@@ -33,7 +41,11 @@ private[spark] trait Identifiable {
override def toString: String = uid
}
-private[spark] object Identifiable {
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+object Identifiable {
/**
* Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
index 2a1db90f2ca2b..fcb517b5f735e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.util
import scala.collection.immutable.HashMap
import org.apache.spark.ml.attribute._
+import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.StructField
@@ -74,4 +75,20 @@ private[spark] object MetadataUtils {
}
}
+ /**
+ * Takes a Vector column and a list of feature names, and returns the corresponding list of
+ * feature indices in the column, in order.
+ * @param col Vector column which must have feature names specified via attributes
+ * @param names List of feature names
+ */
+ def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = {
+ require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col"
+ + s" to be Vector type, but it was type ${col.dataType} instead.")
+ val inputAttr = AttributeGroup.fromStructField(col)
+ names.map { name =>
+ require(inputAttr.hasAttr(name),
+ s"getFeatureIndicesFromNames found no feature with name $name in column $col.")
+ inputAttr.getAttr(name).index.get
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 6f080d32bbf4d..f585aacd452e0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -37,6 +37,7 @@ import org.apache.spark.mllib.evaluation.RankingMetrics
import org.apache.spark.mllib.feature._
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.distributed._
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
import org.apache.spark.mllib.recommendation._
@@ -54,7 +55,7 @@ import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomFo
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -1096,6 +1097,81 @@ private[python] class PythonMLLibAPI extends Serializable {
Statistics.kolmogorovSmirnovTest(data, distName, paramsSeq: _*)
}
+ /**
+ * Wrapper around RowMatrix constructor.
+ */
+ def createRowMatrix(rows: JavaRDD[Vector], numRows: Long, numCols: Int): RowMatrix = {
+ new RowMatrix(rows.rdd, numRows, numCols)
+ }
+
+ /**
+ * Wrapper around IndexedRowMatrix constructor.
+ */
+ def createIndexedRowMatrix(rows: DataFrame, numRows: Long, numCols: Int): IndexedRowMatrix = {
+ // We use DataFrames for serialization of IndexedRows from Python,
+ // so map each Row in the DataFrame back to an IndexedRow.
+ val indexedRows = rows.map {
+ case Row(index: Long, vector: Vector) => IndexedRow(index, vector)
+ }
+ new IndexedRowMatrix(indexedRows, numRows, numCols)
+ }
+
+ /**
+ * Wrapper around CoordinateMatrix constructor.
+ */
+ def createCoordinateMatrix(rows: DataFrame, numRows: Long, numCols: Long): CoordinateMatrix = {
+ // We use DataFrames for serialization of MatrixEntry entries from
+ // Python, so map each Row in the DataFrame back to a MatrixEntry.
+ val entries = rows.map {
+ case Row(i: Long, j: Long, value: Double) => MatrixEntry(i, j, value)
+ }
+ new CoordinateMatrix(entries, numRows, numCols)
+ }
+
+ /**
+ * Wrapper around BlockMatrix constructor.
+ */
+ def createBlockMatrix(blocks: DataFrame, rowsPerBlock: Int, colsPerBlock: Int,
+ numRows: Long, numCols: Long): BlockMatrix = {
+ // We use DataFrames for serialization of sub-matrix blocks from
+ // Python, so map each Row in the DataFrame back to a
+ // ((blockRowIndex, blockColIndex), sub-matrix) tuple.
+ val blockTuples = blocks.map {
+ case Row(Row(blockRowIndex: Long, blockColIndex: Long), subMatrix: Matrix) =>
+ ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix)
+ }
+ new BlockMatrix(blockTuples, rowsPerBlock, colsPerBlock, numRows, numCols)
+ }
+
+ /**
+ * Return the rows of an IndexedRowMatrix.
+ */
+ def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = {
+ // We use DataFrames for serialization of IndexedRows to Python,
+ // so return a DataFrame.
+ val sqlContext = new SQLContext(indexedRowMatrix.rows.sparkContext)
+ sqlContext.createDataFrame(indexedRowMatrix.rows)
+ }
+
+ /**
+ * Return the entries of a CoordinateMatrix.
+ */
+ def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = {
+ // We use DataFrames for serialization of MatrixEntry entries to
+ // Python, so return a DataFrame.
+ val sqlContext = new SQLContext(coordinateMatrix.entries.sparkContext)
+ sqlContext.createDataFrame(coordinateMatrix.entries)
+ }
+
+ /**
+ * Return the sub-matrix blocks of a BlockMatrix.
+ */
+ def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = {
+ // We use DataFrames for serialization of sub-matrix blocks to
+ // Python, so return a DataFrame.
+ val sqlContext = new SQLContext(blockMatrix.blocks.sparkContext)
+ sqlContext.createDataFrame(blockMatrix.blocks)
+ }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index cb807c8038101..76aeebd703d4e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -66,6 +66,12 @@ class GaussianMixtureModel(
responsibilityMatrix.map(r => r.indexOf(r.max))
}
+ /** Maps given point to its cluster index. */
+ def predict(point: Vector): Int = {
+ val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+ r.indexOf(r.max)
+ }
+
/** Java-friendly version of [[predict()]] */
def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
@@ -83,6 +89,13 @@ class GaussianMixtureModel(
}
}
+ /**
+ * Given the input vector, return the membership values to all mixture components.
+ */
+ def predictSoft(point: Vector): Array[Double] = {
+ computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+ }
+
/**
* Compute the partial assignments for each vector
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 8ecb3df11d95e..96359024fa228 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -120,11 +120,11 @@ object KMeansModel extends Loader[KMeansModel] {
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
val k = (metadata \ "k").extract[Int]
- val centriods = sqlContext.read.parquet(Loader.dataPath(path))
- Loader.checkSchema[Cluster](centriods.schema)
- val localCentriods = centriods.map(Cluster.apply).collect()
- assert(k == localCentriods.size)
- new KMeansModel(localCentriods.sortBy(_.id).map(_.point))
+ val centroids = sqlContext.read.parquet(Loader.dataPath(path))
+ Loader.checkSchema[Cluster](centroids.schema)
+ val localCentroids = centroids.map(Cluster.apply).collect()
+ assert(k == localCentroids.size)
+ new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
}
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index ab124e6d77c5e..0fc9b1ac4d716 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -79,7 +79,24 @@ class LDA private (
*
* This is the parameter to a Dirichlet distribution.
*/
- def getDocConcentration: Vector = this.docConcentration
+ def getAsymmetricDocConcentration: Vector = this.docConcentration
+
+ /**
+ * Concentration parameter (commonly named "alpha") for the prior placed on documents'
+ * distributions over topics ("theta").
+ *
+ * This method assumes the Dirichlet distribution is symmetric and can be described by a single
+ * [[Double]] parameter. It should fail if docConcentration is asymmetric.
+ */
+ def getDocConcentration: Double = {
+ val parameter = docConcentration(0)
+ if (docConcentration.size == 1) {
+ parameter
+ } else {
+ require(docConcentration.toArray.forall(_ == parameter))
+ parameter
+ }
+ }
/**
* Concentration parameter (commonly named "alpha") for the prior placed on documents'
@@ -106,18 +123,22 @@ class LDA private (
* [[https://github.com/Blei-Lab/onlineldavb]].
*/
def setDocConcentration(docConcentration: Vector): this.type = {
+ require(docConcentration.size > 0, "docConcentration must have > 0 elements")
this.docConcentration = docConcentration
this
}
- /** Replicates Double to create a symmetric prior */
+ /** Replicates a [[Double]] docConcentration to create a symmetric prior. */
def setDocConcentration(docConcentration: Double): this.type = {
this.docConcentration = Vectors.dense(docConcentration)
this
}
+ /** Alias for [[getAsymmetricDocConcentration]] */
+ def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration
+
/** Alias for [[getDocConcentration]] */
- def getAlpha: Vector = getDocConcentration
+ def getAlpha: Double = getDocConcentration
/** Alias for [[setDocConcentration()]] */
def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 6cfad3fbbdb87..f31949f13a4cf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum}
import breeze.numerics.{exp, lgamma}
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
@@ -26,7 +26,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
@@ -86,10 +86,6 @@ abstract class LDAModel private[clustering] extends Saveable {
/**
* Return the topics described by weighted terms.
*
- * This limits the number of terms per topic.
- * This is approximate; it may not return exactly the top-weighted terms for each topic.
- * To get a more precise set of top terms, increase maxTermsPerTopic.
- *
* @param maxTermsPerTopic Maximum number of terms to collect for each topic.
* @return Array over topics. Each topic is represented as a pair of matching arrays:
* (term indices, term weights in topic).
@@ -193,7 +189,8 @@ class LocalLDAModel private[clustering] (
val topics: Matrix,
override val docConcentration: Vector,
override val topicConcentration: Double,
- override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable {
+ override protected[clustering] val gammaShape: Double = 100)
+ extends LDAModel with Serializable {
override def k: Int = topics.numCols
@@ -217,24 +214,42 @@ class LocalLDAModel private[clustering] (
LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
gammaShape)
}
- // TODO
- // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
+
+ // TODO: declare in LDAModel and override once implemented in DistributedLDAModel
+ /**
+ * Calculates a lower bound on the log likelihood of the entire corpus.
+ *
+ * See Equation (16) in original Online LDA paper.
+ *
+ * @param documents test corpus to use for calculating log likelihood
+ * @return variational lower bound on the log likelihood of the entire corpus
+ */
+ def logLikelihood(documents: RDD[(Long, Vector)]): Double = logLikelihoodBound(documents,
+ docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k,
+ vocabSize)
+
+ /** Java-friendly version of [[logLikelihood]] */
+ def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = {
+ logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
+ }
/**
- * Calculate the log variational bound on perplexity. See Equation (16) in original Online
- * LDA paper.
+ * Calculate an upper bound on perplexity. (Lower is better.)
+ * See Equation (16) in original Online LDA paper.
+ *
* @param documents test corpus to use for calculating perplexity
- * @return the log perplexity per word
+ * @return Variational upper bound on log perplexity per token.
*/
def logPerplexity(documents: RDD[(Long, Vector)]): Double = {
- val corpusWords = documents
+ val corpusTokenCount = documents
.map { case (_, termCounts) => termCounts.toArray.sum }
.sum()
- val batchVariationalBound = bound(documents, docConcentration,
- topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize)
- val perWordBound = batchVariationalBound / corpusWords
+ -logLikelihood(documents) / corpusTokenCount
+ }
- perWordBound
+ /** Java-friendly version of [[logPerplexity]] */
+ def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = {
+ logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
}
/**
@@ -242,17 +257,20 @@ class LocalLDAModel private[clustering] (
* log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)]
* This bound is derived by decomposing the LDA model to:
* log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p)
- * and noting that the KL-divergence D(q|p) >= 0. See Equation (16) in original Online LDA paper.
+ * and noting that the KL-divergence D(q|p) >= 0.
+ *
+ * See Equation (16) in original Online LDA paper, as well as Appendix A.3 in the JMLR version of
+ * the original LDA paper.
* @param documents a subset of the test corpus
* @param alpha document-topic Dirichlet prior parameters
- * @param eta topic-word Dirichlet prior parameters
+ * @param eta topic-word Dirichlet prior parameter
* @param lambda parameters for variational q(beta | lambda) topic-word distributions
* @param gammaShape shape parameter for random initialization of variational q(theta | gamma)
* topic mixture distributions
* @param k number of topics
* @param vocabSize number of unique terms in the entire test corpus
*/
- private def bound(
+ private def logLikelihoodBound(
documents: RDD[(Long, Vector)],
alpha: Vector,
eta: Double,
@@ -264,33 +282,38 @@ class LocalLDAModel private[clustering] (
// transpose because dirichletExpectation normalizes by row and we need to normalize
// by topic (columns of lambda)
val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
+ val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta)
+
+ // Sum bound components for each document:
+ // component for prob(tokens) + component for prob(document-topic distribution)
+ val corpusPart =
+ documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) =>
+ val localElogbeta = ElogbetaBc.value
+ var docBound = 0.0D
+ val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference(
+ termCounts, exp(localElogbeta), brzAlpha, gammaShape, k)
+ val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
+
+ // E[log p(doc | theta, beta)]
+ termCounts.foreachActive { case (idx, count) =>
+ docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t)
+ }
+ // E[log p(theta | alpha) - log q(theta | gamma)]
+ docBound += sum((brzAlpha - gammad) :* Elogthetad)
+ docBound += sum(lgamma(gammad) - lgamma(brzAlpha))
+ docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))
- var score = documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) =>
- var docScore = 0.0D
- val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference(
- termCounts, exp(Elogbeta), brzAlpha, gammaShape, k)
- val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
-
- // E[log p(doc | theta, beta)]
- termCounts.foreachActive { case (idx, count) =>
- docScore += count * LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t)
- }
- // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector
- docScore += sum((brzAlpha - gammad) :* Elogthetad)
- docScore += sum(lgamma(gammad) - lgamma(brzAlpha))
- docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))
-
- docScore
- }.sum()
-
- // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar
- score += sum((eta - lambda) :* Elogbeta)
- score += sum(lgamma(lambda) - lgamma(eta))
+ docBound
+ }.sum()
+ // Bound component for prob(topic-term distributions):
+ // E[log p(beta | eta) - log q(beta | lambda)]
val sumEta = eta * vocabSize
- score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))
+ val topicsPart = sum((eta - lambda) :* Elogbeta) +
+ sum(lgamma(lambda) - lgamma(eta)) +
+ sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))
- score
+ corpusPart + topicsPart
}
/**
@@ -308,6 +331,7 @@ class LocalLDAModel private[clustering] (
// Double transpose because dirichletExpectation normalizes by row and we need to normalize
// by topic (columns of lambda)
val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t)
+ val expElogbetaBc = documents.sparkContext.broadcast(expElogbeta)
val docConcentrationBrz = this.docConcentration.toBreeze
val gammaShape = this.gammaShape
val k = this.k
@@ -318,7 +342,7 @@ class LocalLDAModel private[clustering] (
} else {
val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference(
termCounts,
- expElogbeta,
+ expElogbetaBc.value,
docConcentrationBrz,
gammaShape,
k)
@@ -327,8 +351,14 @@ class LocalLDAModel private[clustering] (
}
}
-}
+ /** Java-friendly version of [[topicDistributions]] */
+ def topicDistributions(
+ documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = {
+ val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
+ JavaPairRDD.fromRDD(distributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
+ }
+}
@Experimental
object LocalLDAModel extends Loader[LocalLDAModel] {
@@ -441,8 +471,9 @@ class DistributedLDAModel private[clustering] (
val vocabSize: Int,
override val docConcentration: Vector,
override val topicConcentration: Double,
- override protected[clustering] val gammaShape: Double,
- private[spark] val iterationTimes: Array[Double]) extends LDAModel {
+ private[spark] val iterationTimes: Array[Double],
+ override protected[clustering] val gammaShape: Double = 100)
+ extends LDAModel {
import LDA._
@@ -510,6 +541,40 @@ class DistributedLDAModel private[clustering] (
}
}
+ /**
+ * Return the top documents for each topic
+ *
+ * @param maxDocumentsPerTopic Maximum number of documents to collect for each topic.
+ * @return Array over topics. Each element is represented as a pair of matching arrays:
+ * (IDs for the documents, weights of the topic in these documents).
+ * For each topic, documents are sorted in order of decreasing topic weights.
+ */
+ def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = {
+ val numTopics = k
+ val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] =
+ topicDistributions.mapPartitions { docVertices =>
+ // For this partition, collect the most common docs for each topic in queues:
+ // queues(topic) = queue of (doc topic, doc ID).
+ val queues =
+ Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Long)](maxDocumentsPerTopic))
+ for ((docId, docTopics) <- docVertices) {
+ var topic = 0
+ while (topic < numTopics) {
+ queues(topic) += (docTopics(topic) -> docId)
+ topic += 1
+ }
+ }
+ Iterator(queues)
+ }.treeReduce { (q1, q2) =>
+ q1.zip(q2).foreach { case (a, b) => a ++= b }
+ q1
+ }
+ topicsInQueues.map { q =>
+ val (docTopics, docs) = q.toArray.sortBy(-_._1).unzip
+ (docs.toArray, docTopics.toArray)
+ }
+ }
+
// TODO
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
@@ -591,6 +656,30 @@ class DistributedLDAModel private[clustering] (
JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]])
}
+ /**
+ * For each document, return the top k weighted topics for that document and their weights.
+ * @return RDD of (doc ID, topic indices, topic weights)
+ */
+ def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = {
+ graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
+ val topIndices = argtopk(topicCounts, k)
+ val sumCounts = sum(topicCounts)
+ val weights = if (sumCounts != 0) {
+ topicCounts(topIndices) / sumCounts
+ } else {
+ topicCounts(topIndices)
+ }
+ (docID.toLong, topIndices.toArray, weights.toArray)
+ }
+ }
+
+ /** Java-friendly version of [[topTopicsPerDocument]] */
+ def javaTopTopicsPerDocument(
+ k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[java.lang.Double])] = {
+ val topics = topTopicsPerDocument(k)
+ topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[java.lang.Double])]].toJavaRDD()
+ }
+
// TODO:
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
@@ -691,7 +780,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] {
val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
- docConcentration, topicConcentration, gammaShape, iterationTimes)
+ docConcentration, topicConcentration, iterationTimes, gammaShape)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 9dbec41efeada..a0008f9c99ad7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -19,8 +19,8 @@ package org.apache.spark.mllib.clustering
import java.util.Random
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum}
-import breeze.numerics.{abs, exp}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum}
+import breeze.numerics.{trigamma, abs, exp}
import breeze.stats.distributions.{Gamma, RandBasis}
import org.apache.spark.annotation.DeveloperApi
@@ -95,10 +95,8 @@ final class EMLDAOptimizer extends LDAOptimizer {
* Compute bipartite term/doc graph.
*/
override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = {
- val docConcentration = lda.getDocConcentration(0)
- require({
- lda.getDocConcentration.toArray.forall(_ == docConcentration)
- }, "EMLDAOptimizer currently only supports symmetric document-topic priors")
+ // EMLDAOptimizer currently only supports symmetric document-topic priors
+ val docConcentration = lda.getDocConcentration
val topicConcentration = lda.getTopicConcentration
val k = lda.getK
@@ -144,6 +142,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
this.checkpointInterval = lda.getCheckpointInterval
this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
checkpointInterval, graph.vertices.sparkContext)
+ this.graphCheckpointer.update(this.graph)
this.globalTopicTotals = computeGlobalTopicTotals()
this
}
@@ -208,11 +207,11 @@ final class EMLDAOptimizer extends LDAOptimizer {
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
this.graphCheckpointer.deleteAllCheckpoints()
- // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal
- // conversion
+ // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in
+ // LDAModel.toLocal conversion
new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
- 100, iterationTimes)
+ iterationTimes)
}
}
@@ -238,22 +237,26 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
/** alias for docConcentration */
private var alpha: Vector = Vectors.dense(0)
- /** (private[clustering] for debugging) Get docConcentration */
+ /** (for debugging) Get docConcentration */
private[clustering] def getAlpha: Vector = alpha
/** alias for topicConcentration */
private var eta: Double = 0
- /** (private[clustering] for debugging) Get topicConcentration */
+ /** (for debugging) Get topicConcentration */
private[clustering] def getEta: Double = eta
private var randomGenerator: java.util.Random = null
+ /** (for debugging) Whether to sample mini-batches with replacement. (default = true) */
+ private var sampleWithReplacement: Boolean = true
+
// Online LDA specific parameters
// Learning rate is: (tau0 + t)^{-kappa}
private var tau0: Double = 1024
private var kappa: Double = 0.51
private var miniBatchFraction: Double = 0.05
+ private var optimizeAlpha: Boolean = false
// internal data structure
private var docs: RDD[(Long, Vector)] = null
@@ -261,7 +264,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
/** Dirichlet parameter for the posterior over topics */
private var lambda: BDM[Double] = null
- /** (private[clustering] for debugging) Get parameter for topics */
+ /** (for debugging) Get parameter for topics */
private[clustering] def getLambda: BDM[Double] = lambda
/** Current iteration (count of invocations of [[next()]]) */
@@ -324,7 +327,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}
/**
- * (private[clustering])
+ * Whether alpha (the Dirichlet parameter for the document-topic distribution)
+ * will be optimized during training.
+ */
+ def getOptimizeAlpha: Boolean = this.optimizeAlpha
+
+ /**
+ * Sets whether to optimize alpha parameter during training.
+ *
+ * Default: false
+ */
+ def setOptimizeAlpha(optimizeAlpha: Boolean): this.type = {
+ this.optimizeAlpha = optimizeAlpha
+ this
+ }
+
+ /**
* Set the Dirichlet parameter for the posterior over topics.
* This is only used for testing now. In the future, it can help support training stop/resume.
*/
@@ -334,7 +352,6 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}
/**
- * (private[clustering])
* Used for random initialization of the variational parameters.
* Larger value produces values closer to 1.0.
* This is only used for testing currently.
@@ -344,24 +361,35 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
this
}
+ /**
+ * Sets whether to sample mini-batches with or without replacement. (default = true)
+ * This is only used for testing currently.
+ */
+ private[clustering] def setSampleWithReplacement(replace: Boolean): this.type = {
+ this.sampleWithReplacement = replace
+ this
+ }
+
override private[clustering] def initialize(
docs: RDD[(Long, Vector)],
lda: LDA): OnlineLDAOptimizer = {
this.k = lda.getK
this.corpusSize = docs.count()
this.vocabSize = docs.first()._2.size
- this.alpha = if (lda.getDocConcentration.size == 1) {
- if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
+ this.alpha = if (lda.getAsymmetricDocConcentration.size == 1) {
+ if (lda.getAsymmetricDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k))
else {
- require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha")
- Vectors.dense(Array.fill(k)(lda.getDocConcentration(0)))
+ require(lda.getAsymmetricDocConcentration(0) >= 0,
+ s"all entries in alpha must be >=0, got: $alpha")
+ Vectors.dense(Array.fill(k)(lda.getAsymmetricDocConcentration(0)))
}
} else {
- require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha")
- lda.getDocConcentration.foreachActive { case (_, x) =>
+ require(lda.getAsymmetricDocConcentration.size == k,
+ s"alpha must have length k, got: $alpha")
+ lda.getAsymmetricDocConcentration.foreachActive { case (_, x) =>
require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha")
}
- lda.getDocConcentration
+ lda.getAsymmetricDocConcentration
}
this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration
this.randomGenerator = new Random(lda.getSeed)
@@ -375,7 +403,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}
override private[clustering] def next(): OnlineLDAOptimizer = {
- val batch = docs.sample(withReplacement = true, miniBatchFraction, randomGenerator.nextLong())
+ val batch = docs.sample(withReplacement = sampleWithReplacement, miniBatchFraction,
+ randomGenerator.nextLong())
if (batch.isEmpty()) return this
submitMiniBatch(batch)
}
@@ -390,6 +419,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
val k = this.k
val vocabSize = this.vocabSize
val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t
+ val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta)
val alpha = this.alpha.toBreeze
val gammaShape = this.gammaShape
@@ -404,19 +434,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
case v: SparseVector => v.indices.toList
}
val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference(
- termCounts, expElogbeta, alpha, gammaShape, k)
+ termCounts, expElogbetaBc.value, alpha, gammaShape, k)
stat(::, ids) := stat(::, ids).toDenseMatrix + sstats
gammaPart = gammad :: gammaPart
}
Iterator((stat, gammaPart))
}
val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
+ expElogbetaBc.unpersist()
val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
val batchResult = statsSum :* expElogbeta.t
// Note that this is an optimization to avoid batch.count
updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
+ if (optimizeAlpha) updateAlpha(gammat)
this
}
@@ -432,13 +464,39 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta)
}
- /** Calculates learning rate rho, which decays as a function of [[iteration]] */
+ /**
+ * Update alpha based on `gammat`, the inferred topic distributions for documents in the
+ * current mini-batch. Uses the Newton-Raphson method.
+ * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters
+ * (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf)
+ */
+ private def updateAlpha(gammat: BDM[Double]): Unit = {
+ val weight = rho()
+ val N = gammat.rows.toDouble
+ val alpha = this.alpha.toBreeze.toDenseVector
+ val logphat: BDM[Double] = sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)) / N
+ val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat.toDenseVector)
+
+ val c = N * trigamma(sum(alpha))
+ val q = -N * trigamma(alpha)
+ val b = sum(gradf / q) / (1D / c + sum(1D / q))
+
+ val dalpha = -(gradf - b) / q
+
+ if (all((weight * dalpha + alpha) :> 0D)) {
+ alpha :+= weight * dalpha
+ this.alpha = Vectors.dense(alpha.toArray)
+ }
+ }
+
+
+ /** Calculate learning rate rho for the current [[iteration]]. */
private def rho(): Double = {
math.pow(getTau0 + this.iteration, -getKappa)
}
/**
- * Get a random matrix to initialize lambda
+ * Get a random matrix to initialize lambda.
*/
private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(
@@ -484,21 +542,22 @@ private[clustering] object OnlineLDAOptimizer {
val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K
val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K
- val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
- var meanchange = 1D
+ val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
+ var meanGammaChange = 1D
val ctsVector = new BDV[Double](cts) // ids
// Iterate between gamma and phi until convergence
- while (meanchange > 1e-3) {
+ while (meanGammaChange > 1e-3) {
val lastgamma = gammad.copy
// K K * ids ids
- gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha
+ gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha
expElogthetad := exp(LDAUtils.dirichletExpectation(gammad))
- phinorm := expElogbetad * expElogthetad :+ 1e-100
- meanchange = sum(abs(gammad - lastgamma)) / k
+ // TODO: Keep more values in log space, and only exponentiate when needed.
+ phiNorm := expElogbetad * expElogthetad :+ 1e-100
+ meanGammaChange = sum(abs(gammad - lastgamma)) / k
}
- val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix
+ val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix
(gammad, sstatsd)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
index f7e5ce1665fe6..a9ba7b60bad08 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
@@ -22,7 +22,7 @@ import breeze.numerics._
/**
* Utility methods for LDA.
*/
-object LDAUtils {
+private[clustering] object LDAUtils {
/**
* Log Sum Exp with overflow protection using the identity:
* For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index c1d1a224817e8..486741edd6f5a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.DataFrame
* of bins may not exactly equal numBins. The last bin in each partition may
* be smaller as a result, meaning there may be an extra sample at
* partition boundaries.
+ * @since 1.3.0
*/
@Experimental
class BinaryClassificationMetrics(
@@ -51,6 +52,7 @@ class BinaryClassificationMetrics(
/**
* Defaults `numBins` to 0.
+ * @since 1.0.0
*/
def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0)
@@ -61,12 +63,18 @@ class BinaryClassificationMetrics(
private[mllib] def this(scoreAndLabels: DataFrame) =
this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))
- /** Unpersist intermediate RDDs used in the computation. */
+ /**
+ * Unpersist intermediate RDDs used in the computation.
+ * @since 1.0.0
+ */
def unpersist() {
cumulativeCounts.unpersist()
}
- /** Returns thresholds in descending order. */
+ /**
+ * Returns thresholds in descending order.
+ * @since 1.0.0
+ */
def thresholds(): RDD[Double] = cumulativeCounts.map(_._1)
/**
@@ -74,6 +82,7 @@ class BinaryClassificationMetrics(
* which is an RDD of (false positive rate, true positive rate)
* with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
* @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
+ * @since 1.0.0
*/
def roc(): RDD[(Double, Double)] = {
val rocCurve = createCurve(FalsePositiveRate, Recall)
@@ -85,6 +94,7 @@ class BinaryClassificationMetrics(
/**
* Computes the area under the receiver operating characteristic (ROC) curve.
+ * @since 1.0.0
*/
def areaUnderROC(): Double = AreaUnderCurve.of(roc())
@@ -92,6 +102,7 @@ class BinaryClassificationMetrics(
* Returns the precision-recall curve, which is an RDD of (recall, precision),
* NOT (precision, recall), with (0.0, 1.0) prepended to it.
* @see http://en.wikipedia.org/wiki/Precision_and_recall
+ * @since 1.0.0
*/
def pr(): RDD[(Double, Double)] = {
val prCurve = createCurve(Recall, Precision)
@@ -102,6 +113,7 @@ class BinaryClassificationMetrics(
/**
* Computes the area under the precision-recall curve.
+ * @since 1.0.0
*/
def areaUnderPR(): Double = AreaUnderCurve.of(pr())
@@ -110,16 +122,26 @@ class BinaryClassificationMetrics(
* @param beta the beta factor in F-Measure computation.
* @return an RDD of (threshold, F-Measure) pairs.
* @see http://en.wikipedia.org/wiki/F1_score
+ * @since 1.0.0
*/
def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))
- /** Returns the (threshold, F-Measure) curve with beta = 1.0. */
+ /**
+ * Returns the (threshold, F-Measure) curve with beta = 1.0.
+ * @since 1.0.0
+ */
def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0)
- /** Returns the (threshold, precision) curve. */
+ /**
+ * Returns the (threshold, precision) curve.
+ * @since 1.0.0
+ */
def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision)
- /** Returns the (threshold, recall) curve. */
+ /**
+ * Returns the (threshold, recall) curve.
+ * @since 1.0.0
+ */
def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall)
private lazy val (
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index 4628dc5690913..dddfa3ea5b800 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.DataFrame
* Evaluator for multiclass classification.
*
* @param predictionAndLabels an RDD of (prediction, label) pairs.
+ * @since 1.1.0
*/
@Experimental
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
@@ -64,6 +65,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
* predicted classes are in columns,
* they are ordered by class label ascending,
* as in "labels"
+ * @since 1.1.0
*/
def confusionMatrix: Matrix = {
val n = labels.size
@@ -83,12 +85,14 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
/**
* Returns true positive rate for a given label (category)
* @param label the label.
+ * @since 1.1.0
*/
def truePositiveRate(label: Double): Double = recall(label)
/**
* Returns false positive rate for a given label (category)
* @param label the label.
+ * @since 1.1.0
*/
def falsePositiveRate(label: Double): Double = {
val fp = fpByClass.getOrElse(label, 0)
@@ -98,6 +102,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
/**
* Returns precision for a given label (category)
* @param label the label.
+ * @since 1.1.0
*/
def precision(label: Double): Double = {
val tp = tpByClass(label)
@@ -108,6 +113,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
/**
* Returns recall for a given label (category)
* @param label the label.
+ * @since 1.1.0
*/
def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label)
@@ -115,6 +121,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
* Returns f-measure for a given label (category)
* @param label the label.
* @param beta the beta parameter.
+ * @since 1.1.0
*/
def fMeasure(label: Double, beta: Double): Double = {
val p = precision(label)
@@ -126,6 +133,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
/**
* Returns f1-measure for a given label (category)
* @param label the label.
+ * @since 1.1.0
*/
def fMeasure(label: Double): Double = fMeasure(label, 1.0)
@@ -179,6 +187,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
/**
* Returns weighted averaged f-measure
* @param beta the beta parameter.
+ * @since 1.1.0
*/
def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) =>
fMeasure(category, beta) * count.toDouble / labelCount
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
index bf6eb1d5bd2ab..77cb1e09bdbb5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.DataFrame
* Evaluator for multilabel classification.
* @param predictionAndLabels an RDD of (predictions, labels) pairs,
* both are non-null Arrays, each with unique elements.
+ * @since 1.2.0
*/
class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
@@ -103,6 +104,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
/**
* Returns precision for a given label (category)
* @param label the label.
+ * @since 1.2.0
*/
def precision(label: Double): Double = {
val tp = tpPerClass(label)
@@ -113,6 +115,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
/**
* Returns recall for a given label (category)
* @param label the label.
+ * @since 1.2.0
*/
def recall(label: Double): Double = {
val tp = tpPerClass(label)
@@ -123,6 +126,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]
/**
* Returns f1-measure for a given label (category)
* @param label the label.
+ * @since 1.2.0
*/
def f1Measure(label: Double): Double = {
val p = precision(label)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index 5b5a2a1450f7f..063fbed8cdeea 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
* Java users should use [[RankingMetrics$.of]] to create a [[RankingMetrics]] instance.
*
* @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
+ * @since 1.2.0
*/
@Experimental
class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
@@ -55,6 +56,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
*
* @param k the position to compute the truncated precision, must be positive
* @return the average precision at the first k ranking positions
+ * @since 1.2.0
*/
def precisionAt(k: Int): Double = {
require(k > 0, "ranking position k should be positive")
@@ -124,6 +126,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
*
* @param k the position to compute the truncated ndcg, must be positive
* @return the average ndcg at the first k ranking positions
+ * @since 1.2.0
*/
def ndcgAt(k: Int): Double = {
require(k > 0, "ranking position k should be positive")
@@ -162,6 +165,7 @@ object RankingMetrics {
/**
* Creates a [[RankingMetrics]] instance (for Java users).
* @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs
+ * @since 1.4.0
*/
def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = {
implicit val tag = JavaSparkContext.fakeClassTag[E]
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
index 408847afa800d..54dfd8c099494 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.DataFrame
* Evaluator for regression.
*
* @param predictionAndObservations an RDD of (prediction, observation) pairs.
+ * @since 1.2.0
*/
@Experimental
class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {
@@ -66,6 +67,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
* Returns the variance explained by regression.
* explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n
* @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]]
+ * @since 1.2.0
*/
def explainedVariance: Double = {
SSreg / summary.count
@@ -74,6 +76,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
/**
* Returns the mean absolute error, which is a risk function corresponding to the
* expected value of the absolute error loss or l1-norm loss.
+ * @since 1.2.0
*/
def meanAbsoluteError: Double = {
summary.normL1(1) / summary.count
@@ -82,6 +85,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
/**
* Returns the mean squared error, which is a risk function corresponding to the
* expected value of the squared error loss or quadratic loss.
+ * @since 1.2.0
*/
def meanSquaredError: Double = {
SSerr / summary.count
@@ -90,6 +94,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
/**
* Returns the root mean squared error, which is defined as the square root of
* the mean squared error.
+ * @since 1.2.0
*/
def rootMeanSquaredError: Double = {
math.sqrt(this.meanSquaredError)
@@ -98,6 +103,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
/**
* Returns R^2^, the unadjusted coefficient of determination.
* @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+ * @since 1.2.0
*/
def r2: Double = {
1 - SSerr / SStot
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 0ea792081086d..3ea10779a1837 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -22,73 +22,89 @@ import scala.collection.mutable
import org.apache.spark.Logging
/**
- * Calculate all patterns of a projected database in local.
+ * Calculate all patterns of a projected database in local mode.
+ *
+ * @param minCount minimal count for a frequent pattern
+ * @param maxPatternLength max pattern length for a frequent pattern
*/
-private[fpm] object LocalPrefixSpan extends Logging with Serializable {
+private[fpm] class LocalPrefixSpan(
+ val minCount: Long,
+ val maxPatternLength: Int) extends Logging with Serializable {
+ import PrefixSpan.Postfix
+ import LocalPrefixSpan.ReversedPrefix
/**
- * Calculate all patterns of a projected database.
- * @param minCount minimum count
- * @param maxPatternLength maximum pattern length
- * @param prefixes prefixes in reversed order
- * @param database the projected database
- * @return a set of sequential pattern pairs,
- * the key of pair is sequential pattern (a list of items in reversed order),
- * the value of pair is the pattern's count.
+ * Generates frequent patterns on the input array of postfixes.
+ * @param postfixes an array of postfixes
+ * @return an iterator of (frequent pattern, count)
*/
- def run(
- minCount: Long,
- maxPatternLength: Int,
- prefixes: List[Int],
- database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
- if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
- val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
- val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
- frequentItemAndCounts.iterator.flatMap { case (item, count) =>
- val newPrefixes = item :: prefixes
- val newProjected = project(filteredDatabase, item)
- Iterator.single((newPrefixes, count)) ++
- run(minCount, maxPatternLength, newPrefixes, newProjected)
+ def run(postfixes: Array[Postfix]): Iterator[(Array[Int], Long)] = {
+ genFreqPatterns(ReversedPrefix.empty, postfixes).map { case (prefix, count) =>
+ (prefix.toSequence, count)
}
}
/**
- * Calculate suffix sequence immediately after the first occurrence of an item.
- * @param item item to get suffix after
- * @param sequence sequence to extract suffix from
- * @return suffix sequence
+ * Recursively generates frequent patterns.
+ * @param prefix current prefix
+ * @param postfixes projected postfixes w.r.t. the prefix
+ * @return an iterator of (prefix, count)
*/
- def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
- val index = sequence.indexOf(item)
- if (index == -1) {
- Array()
- } else {
- sequence.drop(index + 1)
+ private def genFreqPatterns(
+ prefix: ReversedPrefix,
+ postfixes: Array[Postfix]): Iterator[(ReversedPrefix, Long)] = {
+ if (maxPatternLength == prefix.length || postfixes.length < minCount) {
+ return Iterator.empty
+ }
+ // find frequent items
+ val counts = mutable.Map.empty[Int, Long].withDefaultValue(0)
+ postfixes.foreach { postfix =>
+ postfix.genPrefixItems.foreach { case (x, _) =>
+ counts(x) += 1L
+ }
+ }
+ val freqItems = counts.toSeq.filter { case (_, count) =>
+ count >= minCount
+ }.sorted
+ // project and recursively call genFreqPatterns
+ freqItems.toIterator.flatMap { case (item, count) =>
+ val newPrefix = prefix :+ item
+ Iterator.single((newPrefix, count)) ++ {
+ val projected = postfixes.map(_.project(item)).filter(_.nonEmpty)
+ genFreqPatterns(newPrefix, projected)
+ }
}
}
+}
- def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
- database
- .map(getSuffix(prefix, _))
- .filter(_.nonEmpty)
- }
+private object LocalPrefixSpan {
/**
- * Generates frequent items by filtering the input data using minimal count level.
- * @param minCount the minimum count for an item to be frequent
- * @param database database of sequences
- * @return freq item to count map
+ * Represents a prefix stored as a list in reversed order.
+ * @param items items in the prefix in reversed order
+ * @param length length of the prefix, not counting delimiters
*/
- private def getFreqItemAndCounts(
- minCount: Long,
- database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
- // TODO: use PrimitiveKeyOpenHashMap
- val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
- database.foreach { sequence =>
- sequence.distinct.foreach { item =>
- counts(item) += 1L
+ class ReversedPrefix private (val items: List[Int], val length: Int) extends Serializable {
+ /**
+ * Expands the prefix by one item.
+ */
+ def :+(item: Int): ReversedPrefix = {
+ require(item != 0)
+ if (item < 0) {
+ new ReversedPrefix(-item :: items, length + 1)
+ } else {
+ new ReversedPrefix(item :: 0 :: items, length + 1)
}
}
- counts.filter(_._2 >= minCount)
+
+ /**
+ * Converts this prefix to a sequence.
+ */
+ def toSequence: Array[Int] = (0 :: items).toArray.reverse
+ }
+
+ object ReversedPrefix {
+ /** An empty prefix. */
+ val empty: ReversedPrefix = new ReversedPrefix(List.empty, 0)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index e6752332cdeeb..dc4ae1d0b69ed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -17,25 +17,35 @@
package org.apache.spark.mllib.fpm
-import scala.collection.mutable.ArrayBuffer
+import java.{lang => jl, util => ju}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
/**
- *
* :: Experimental ::
*
- * A parallel PrefixSpan algorithm to mine sequential pattern.
- * The PrefixSpan algorithm is described in
- * [[http://doi.org/10.1109/ICDE.2001.914830]].
+ * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
+ * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
+ * Efficiently by Prefix-Projected Pattern Growth ([[http://doi.org/10.1109/ICDE.2001.914830]]).
*
* @param minSupport the minimal support level of the sequential pattern, any pattern appears
* more than (minSupport * size-of-the-dataset) times will be output
* @param maxPatternLength the maximal length of the sequential pattern, any pattern appears
- * less than maxPatternLength will be output
+ * less than maxPatternLength will be output
+ * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal
+ * storage format) allowed in a projected database before local
+ * processing. If a projected database exceeds this size, another
+ * iteration of distributed prefix growth is run.
*
* @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
* (Wikipedia)]]
@@ -43,32 +53,28 @@ import org.apache.spark.storage.StorageLevel
@Experimental
class PrefixSpan private (
private var minSupport: Double,
- private var maxPatternLength: Int) extends Logging with Serializable {
-
- /**
- * The maximum number of items allowed in a projected database before local processing. If a
- * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
- */
- // TODO: make configurable with a better default value, 10000 may be too small
- private val maxLocalProjDBSize: Long = 10000
+ private var maxPatternLength: Int,
+ private var maxLocalProjDBSize: Long) extends Logging with Serializable {
+ import PrefixSpan._
/**
* Constructs a default instance with default parameters
- * {minSupport: `0.1`, maxPatternLength: `10`}.
+ * {minSupport: `0.1`, maxPatternLength: `10`, maxLocalProjDBSize: `32000000L`}.
*/
- def this() = this(0.1, 10)
+ def this() = this(0.1, 10, 32000000L)
/**
* Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
* frequent).
*/
- def getMinSupport: Double = this.minSupport
+ def getMinSupport: Double = minSupport
/**
* Sets the minimal support level (default: `0.1`).
*/
def setMinSupport(minSupport: Double): this.type = {
- require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
+ require(minSupport >= 0 && minSupport <= 1,
+ s"The minimum support value must be in [0, 1], but got $minSupport.")
this.minSupport = minSupport
this
}
@@ -76,174 +82,471 @@ class PrefixSpan private (
/**
* Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider.
*/
- def getMaxPatternLength: Double = this.maxPatternLength
+ def getMaxPatternLength: Int = maxPatternLength
/**
* Sets maximal pattern length (default: `10`).
*/
def setMaxPatternLength(maxPatternLength: Int): this.type = {
// TODO: support unbounded pattern length when maxPatternLength = 0
- require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
+ require(maxPatternLength >= 1,
+ s"The maximum pattern length value must be greater than 0, but got $maxPatternLength.")
this.maxPatternLength = maxPatternLength
this
}
/**
- * Find the complete set of sequential patterns in the input sequences.
- * @param sequences input data set, contains a set of sequences,
- * a sequence is an ordered list of elements.
- * @return a set of sequential pattern pairs,
- * the key of pair is pattern (a list of elements),
- * the value of pair is the pattern's count.
+ * Gets the maximum number of items allowed in a projected database before local processing.
*/
- def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
- val sc = sequences.sparkContext
+ def getMaxLocalProjDBSize: Long = maxLocalProjDBSize
- if (sequences.getStorageLevel == StorageLevel.NONE) {
+ /**
+ * Sets the maximum number of items (including delimiters used in the internal storage format)
+ * allowed in a projected database before local processing (default: `32000000L`).
+ */
+ def setMaxLocalProjDBSize(maxLocalProjDBSize: Long): this.type = {
+ require(maxLocalProjDBSize >= 0L,
+ s"The maximum local projected database size must be nonnegative, but got $maxLocalProjDBSize")
+ this.maxLocalProjDBSize = maxLocalProjDBSize
+ this
+ }
+
+ /**
+ * Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
+ * @param data sequences of itemsets.
+ * @return a [[PrefixSpanModel]] that contains the frequent patterns
+ */
+ def run[Item: ClassTag](data: RDD[Array[Array[Item]]]): PrefixSpanModel[Item] = {
+ if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
- // Convert min support to a min number of transactions for this dataset
- val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
-
- // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold
- val freqItemCounts = sequences
- .flatMap(seq => seq.distinct.map(item => (item, 1L)))
- .reduceByKey(_ + _)
- .filter(_._2 >= minCount)
- .collect()
-
- // Pairs of (length 1 prefix, suffix consisting of frequent items)
- val itemSuffixPairs = {
- val freqItems = freqItemCounts.map(_._1).toSet
- sequences.flatMap { seq =>
- val filteredSeq = seq.filter(freqItems.contains(_))
- freqItems.flatMap { item =>
- val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
- candidateSuffix match {
- case suffix if !suffix.isEmpty => Some((List(item), suffix))
- case _ => None
+ val totalCount = data.count()
+ logInfo(s"number of sequences: $totalCount")
+ val minCount = math.ceil(minSupport * totalCount).toLong
+ logInfo(s"minimum count for a frequent pattern: $minCount")
+
+ // Find frequent items.
+ val freqItemAndCounts = data.flatMap { itemsets =>
+ val uniqItems = mutable.Set.empty[Item]
+ itemsets.foreach { _.foreach { item =>
+ uniqItems += item
+ }}
+ uniqItems.toIterator.map((_, 1L))
+ }.reduceByKey(_ + _)
+ .filter { case (_, count) =>
+ count >= minCount
+ }.collect()
+ val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1)
+ logInfo(s"number of frequent items: ${freqItems.length}")
+
+ // Keep only frequent items from input sequences and convert them to internal storage.
+ val itemToInt = freqItems.zipWithIndex.toMap
+ val dataInternalRepr = data.flatMap { itemsets =>
+ val allItems = mutable.ArrayBuilder.make[Int]
+ var containsFreqItems = false
+ allItems += 0
+ itemsets.foreach { itemsets =>
+ val items = mutable.ArrayBuilder.make[Int]
+ itemsets.foreach { item =>
+ if (itemToInt.contains(item)) {
+ items += itemToInt(item) + 1 // using 1-indexing in internal format
}
}
+ val result = items.result()
+ if (result.nonEmpty) {
+ containsFreqItems = true
+ allItems ++= result.sorted
+ }
+ allItems += 0
}
- }
+ if (containsFreqItems) {
+ Iterator.single(allItems.result())
+ } else {
+ Iterator.empty
+ }
+ }.persist(StorageLevel.MEMORY_AND_DISK)
- // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
- // frequent length-one prefixes)
- var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
-
- // Remaining work to be locally and distributively processed respectfully
- var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
-
- // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
- // projected database sizes <= `maxLocalProjDBSize`)
- while (pairsForDistributed.count() != 0) {
- val (nextPatternAndCounts, nextPrefixSuffixPairs) =
- extendPrefixes(minCount, pairsForDistributed)
- pairsForDistributed.unpersist()
- val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
- pairsForDistributed = largerPairsPart
- pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
- pairsForLocal ++= smallerPairsPart
- resultsAccumulator ++= nextPatternAndCounts.collect()
- }
+ val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize)
- // Process the small projected databases locally
- val remainingResults = getPatternsInLocal(
- minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
+ def toPublicRepr(pattern: Array[Int]): Array[Array[Item]] = {
+ val sequenceBuilder = mutable.ArrayBuilder.make[Array[Item]]
+ val itemsetBuilder = mutable.ArrayBuilder.make[Item]
+ val n = pattern.length
+ var i = 1
+ while (i < n) {
+ val x = pattern(i)
+ if (x == 0) {
+ sequenceBuilder += itemsetBuilder.result()
+ itemsetBuilder.clear()
+ } else {
+ itemsetBuilder += freqItems(x - 1) // using 1-indexing in internal format
+ }
+ i += 1
+ }
+ sequenceBuilder.result()
+ }
- (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
- .map { case (pattern, count) => (pattern.toArray, count) }
+ val freqSequences = results.map { case (seq: Array[Int], count: Long) =>
+ new FreqSequence(toPublicRepr(seq), count)
+ }
+ new PrefixSpanModel(freqSequences)
}
-
/**
- * Partitions the prefix-suffix pairs by projected database size.
- * @param prefixSuffixPairs prefix (length n) and suffix pairs,
- * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
- * greater than [[maxLocalProjDBSize]]
+ * A Java-friendly version of [[run()]] that reads sequences from a [[JavaRDD]] and returns
+ * frequent sequences in a [[PrefixSpanModel]].
+ * @param data ordered sequences of itemsets stored as Java Iterable of Iterables
+ * @tparam Item item type
+ * @tparam Itemset itemset type, which is an Iterable of Items
+ * @tparam Sequence sequence type, which is an Iterable of Itemsets
+ * @return a [[PrefixSpanModel]] that contains the frequent sequential patterns
*/
- private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
- : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
- val prefixToSuffixSize = prefixSuffixPairs
- .aggregateByKey(0)(
- seqOp = { case (count, suffix) => count + suffix.length },
- combOp = { _ + _ })
- val smallPrefixes = prefixToSuffixSize
- .filter(_._2 <= maxLocalProjDBSize)
- .keys
- .collect()
- .toSet
- val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
- val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
- (small.collect(), large)
+ def run[Item, Itemset <: jl.Iterable[Item], Sequence <: jl.Iterable[Itemset]](
+ data: JavaRDD[Sequence]): PrefixSpanModel[Item] = {
+ implicit val tag = fakeClassTag[Item]
+ run(data.rdd.map(_.asScala.map(_.asScala.toArray).toArray))
}
+}
+
+@Experimental
+object PrefixSpan extends Logging {
+
/**
- * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
- * and remaining work.
- * @param minCount minimum count
- * @param prefixSuffixPairs prefix (length N) and suffix pairs,
- * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
- * prefix, corresponding suffix) pairs.
+ * Find the complete set of frequent sequential patterns in the input sequences.
+ * @param data ordered sequences of itemsets. We represent a sequence internally as Array[Int],
+ * where each itemset is represented by a contiguous sequence of distinct and ordered
+ * positive integers. We use 0 as the delimiter at itemset boundaries, including the
+ * first and the last position.
+ * @return an RDD of (frequent sequential pattern, count) pairs,
+ * @see [[Postfix]]
*/
- private def extendPrefixes(
+ private[fpm] def genFreqPatterns(
+ data: RDD[Array[Int]],
minCount: Long,
- prefixSuffixPairs: RDD[(List[Int], Array[Int])])
- : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
-
- // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
- // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
- val prefixItemPairAndCounts = prefixSuffixPairs
- .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
- .reduceByKey(_ + _)
- .filter(_._2 >= minCount)
-
- // Map from prefix to set of possible next items from suffix
- val prefixToNextItems = prefixItemPairAndCounts
- .keys
- .groupByKey()
- .mapValues(_.toSet)
- .collect()
- .toMap
-
-
- // Frequent patterns with length N+1 and their corresponding counts
- val extendedPrefixAndCounts = prefixItemPairAndCounts
- .map { case ((prefix, item), count) => (item :: prefix, count) }
-
- // Remaining work, all prefixes will have length N+1
- val extendedPrefixAndSuffix = prefixSuffixPairs
- .filter(x => prefixToNextItems.contains(x._1))
- .flatMap { case (prefix, suffix) =>
- val frequentNextItems = prefixToNextItems(prefix)
- val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
- frequentNextItems.flatMap { item =>
- LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
- case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
- case _ => None
+ maxPatternLength: Int,
+ maxLocalProjDBSize: Long): RDD[(Array[Int], Long)] = {
+ val sc = data.sparkContext
+
+ if (data.getStorageLevel == StorageLevel.NONE) {
+ logWarning("Input data is not cached.")
+ }
+
+ val postfixes = data.map(items => new Postfix(items))
+
+ // Local frequent patterns (prefixes) and their counts.
+ val localFreqPatterns = mutable.ArrayBuffer.empty[(Array[Int], Long)]
+ // Prefixes whose projected databases are small.
+ val smallPrefixes = mutable.Map.empty[Int, Prefix]
+ val emptyPrefix = Prefix.empty
+ // Prefixes whose projected databases are large.
+ var largePrefixes = mutable.Map(emptyPrefix.id -> emptyPrefix)
+ while (largePrefixes.nonEmpty) {
+ val numLocalFreqPatterns = localFreqPatterns.length
+ logInfo(s"number of local frequent patterns: $numLocalFreqPatterns")
+ if (numLocalFreqPatterns > 1000000) {
+ logWarning(
+ s"""
+ | Collected $numLocalFreqPatterns local frequent patterns. You may want to consider:
+ | 1. increase minSupport,
+ | 2. decrease maxPatternLength,
+ | 3. increase maxLocalProjDBSize.
+ """.stripMargin)
+ }
+ logInfo(s"number of small prefixes: ${smallPrefixes.size}")
+ logInfo(s"number of large prefixes: ${largePrefixes.size}")
+ val largePrefixArray = largePrefixes.values.toArray
+ val freqPrefixes = postfixes.flatMap { postfix =>
+ largePrefixArray.flatMap { prefix =>
+ postfix.project(prefix).genPrefixItems.map { case (item, postfixSize) =>
+ ((prefix.id, item), (1L, postfixSize))
+ }
+ }
+ }.reduceByKey { case ((c0, s0), (c1, s1)) =>
+ (c0 + c1, s0 + s1)
+ }.filter { case (_, (c, _)) => c >= minCount }
+ .collect()
+ val newLargePrefixes = mutable.Map.empty[Int, Prefix]
+ freqPrefixes.foreach { case ((id, item), (count, projDBSize)) =>
+ val newPrefix = largePrefixes(id) :+ item
+ localFreqPatterns += ((newPrefix.items :+ 0, count))
+ if (newPrefix.length < maxPatternLength) {
+ if (projDBSize > maxLocalProjDBSize) {
+ newLargePrefixes += newPrefix.id -> newPrefix
+ } else {
+ smallPrefixes += newPrefix.id -> newPrefix
}
}
}
+ largePrefixes = newLargePrefixes
+ }
+
+ var freqPatterns = sc.parallelize(localFreqPatterns, 1)
- (extendedPrefixAndCounts, extendedPrefixAndSuffix)
+ val numSmallPrefixes = smallPrefixes.size
+ logInfo(s"number of small prefixes for local processing: $numSmallPrefixes")
+ if (numSmallPrefixes > 0) {
+ // Switch to local processing.
+ val bcSmallPrefixes = sc.broadcast(smallPrefixes)
+ val distributedFreqPattern = postfixes.flatMap { postfix =>
+ bcSmallPrefixes.value.values.map { prefix =>
+ (prefix.id, postfix.project(prefix).compressed)
+ }.filter(_._2.nonEmpty)
+ }.groupByKey().flatMap { case (id, projPostfixes) =>
+ val prefix = bcSmallPrefixes.value(id)
+ val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length)
+ // TODO: We collect projected postfixes into memory. We should also compare the performance
+ // TODO: of keeping them on shuffle files.
+ localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) =>
+ (prefix.items ++ pattern, count)
+ }
+ }
+ // Union local frequent patterns and distributed ones.
+ freqPatterns = freqPatterns ++ distributedFreqPattern
+ }
+
+ freqPatterns
}
/**
- * Calculate the patterns in local.
- * @param minCount the absolute minimum count
- * @param data prefixes and projected sequences data data
- * @return patterns
+ * Represents a prefix.
+ * @param items items in this prefix, using the internal format
+ * @param length length of this prefix, not counting 0
*/
- private def getPatternsInLocal(
- minCount: Long,
- data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
- data.flatMap {
- case (prefix, projDB) =>
- LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
- .map { case (pattern: List[Int], count: Long) =>
- (pattern.reverse, count)
+ private[fpm] class Prefix private (val items: Array[Int], val length: Int) extends Serializable {
+
+ /** A unique id for this prefix. */
+ val id: Int = Prefix.nextId
+
+ /** Expands this prefix by the input item. */
+ def :+(item: Int): Prefix = {
+ require(item != 0)
+ if (item < 0) {
+ new Prefix(items :+ -item, length + 1)
+ } else {
+ new Prefix(items ++ Array(0, item), length + 1)
+ }
+ }
+ }
+
+ private[fpm] object Prefix {
+ /** Internal counter to generate unique IDs. */
+ private val counter: AtomicInteger = new AtomicInteger(-1)
+
+ /** Gets the next unique ID. */
+ private def nextId: Int = counter.incrementAndGet()
+
+ /** An empty [[Prefix]] instance. */
+ val empty: Prefix = new Prefix(Array.empty, 0)
+ }
+
+ /**
+ * An internal representation of a postfix from some projection.
+ * We use one int array to store the items, which might also contain other items from the
+ * original sequence.
+ * Items are represented by positive integers, and items in each itemset must be distinct and
+ * ordered.
 + * We use 0 as the delimiter between itemsets.
+ * For example, a sequence `<(12)(31)1>` is represented by `[0, 1, 2, 0, 1, 3, 0, 1, 0]`.
 + * The postfix of this sequence w.r.t. prefix `<1>` is `<(_2)(13)1>`.
+ * We may reuse the original items array `[0, 1, 2, 0, 1, 3, 0, 1, 0]` to represent the postfix,
+ * and mark the start index of the postfix, which is `2` in this example.
+ * So the active items in this postfix are `[2, 0, 1, 3, 0, 1, 0]`.
+ * We also remember the start indices of partial projections, the ones that split an itemset.
+ * For example, another possible partial projection w.r.t. `<1>` is `<(_3)1>`.
+ * We remember the start indices of partial projections, which is `[2, 5]` in this example.
+ * This data structure makes it easier to do projections.
+ *
+ * @param items a sequence stored as `Array[Int]` containing this postfix
+ * @param start the start index of this postfix in items
+ * @param partialStarts start indices of possible partial projections, strictly increasing
+ */
+ private[fpm] class Postfix(
+ val items: Array[Int],
+ val start: Int = 0,
+ val partialStarts: Array[Int] = Array.empty) extends Serializable {
+
+ require(items.last == 0, s"The last item in a postfix must be zero, but got ${items.last}.")
+ if (partialStarts.nonEmpty) {
+ require(partialStarts.head >= start,
+ "The first partial start cannot be smaller than the start index," +
+ s"but got partialStarts.head = ${partialStarts.head} < start = $start.")
+ }
+
+ /**
+ * Start index of the first full itemset contained in this postfix.
+ */
+ private[this] def fullStart: Int = {
+ var i = start
+ while (items(i) != 0) {
+ i += 1
+ }
+ i
+ }
+
+ /**
+ * Generates length-1 prefix items of this postfix with the corresponding postfix sizes.
+ * There are two types of prefix items:
+ * a) The item can be assembled to the last itemset of the prefix. For example,
 + *    the postfix of `<(12)(123)>1` w.r.t. `<1>` is `<(_2)(123)1>`. The prefix items of this
 + *    postfix that can be assembled to `<1>` are `_2` and `_3`, resulting in new prefixes `<(12)>`
+ * `<(13)>`. We flip the sign in the output to indicate that this is a partial prefix item.
 + *  b) The item can be appended to the prefix. Taking the same example above, the prefix items
 + *     that can be appended to `<1>` are `1`, `2`, and `3`, resulting in new prefixes `<11>`, `<12>`,
+ * and `<13>`.
+ * @return an iterator of (prefix item, corresponding postfix size). If the item is negative, it
+ * indicates a partial prefix item, which should be assembled to the last itemset of the
+ * current prefix. Otherwise, the item should be appended to the current prefix.
+ */
+ def genPrefixItems: Iterator[(Int, Long)] = {
+ val n1 = items.length - 1
 +      // For each unique item (subject to sign) in this sequence, we output exactly one split.
+ // TODO: use PrimitiveKeyOpenHashMap
+ val prefixes = mutable.Map.empty[Int, Long]
+ // a) items that can be assembled to the last itemset of the prefix
+ partialStarts.foreach { start =>
+ var i = start
+ var x = -items(i)
+ while (x != 0) {
+ if (!prefixes.contains(x)) {
+ prefixes(x) = n1 - i
+ }
+ i += 1
+ x = -items(i)
+ }
+ }
+ // b) items that can be appended to the prefix
+ var i = fullStart
+ while (i < n1) {
+ val x = items(i)
+ if (x != 0 && !prefixes.contains(x)) {
+ prefixes(x) = n1 - i
+ }
+ i += 1
+ }
+ prefixes.toIterator
+ }
+
+ /** Tests whether this postfix is non-empty. */
+ def nonEmpty: Boolean = items.length > start + 1
+
+ /**
+ * Projects this postfix with respect to the input prefix item.
+ * @param prefix prefix item. If prefix is positive, we match items in any full itemset; if it
+ * is negative, we do partial projections.
+ * @return the projected postfix
+ */
+ def project(prefix: Int): Postfix = {
+ require(prefix != 0)
+ val n1 = items.length - 1
+ var matched = false
+ var newStart = n1
+ val newPartialStarts = mutable.ArrayBuilder.make[Int]
+ if (prefix < 0) {
+ // Search for partial projections.
+ val target = -prefix
+ partialStarts.foreach { start =>
+ var i = start
+ var x = items(i)
+ while (x != target && x != 0) {
+ i += 1
+ x = items(i)
+ }
+ if (x == target) {
+ i += 1
+ if (!matched) {
+ newStart = i
+ matched = true
+ }
+ if (items(i) != 0) {
+ newPartialStarts += i
+ }
+ }
}
+ } else {
+ // Search for items in full itemsets.
 +        // Though the items are ordered in each itemset, they should be small in practice.
+ // So a sequential scan is sufficient here, compared to bisection search.
+ val target = prefix
+ var i = fullStart
+ while (i < n1) {
+ val x = items(i)
+ if (x == target) {
+ if (!matched) {
+ newStart = i
+ matched = true
+ }
+ if (items(i + 1) != 0) {
+ newPartialStarts += i + 1
+ }
+ }
+ i += 1
+ }
+ }
+ new Postfix(items, newStart, newPartialStarts.result())
+ }
+
+ /**
+ * Projects this postfix with respect to the input prefix.
+ */
+ private def project(prefix: Array[Int]): Postfix = {
+ var partial = true
+ var cur = this
+ var i = 0
+ val np = prefix.length
+ while (i < np && cur.nonEmpty) {
+ val x = prefix(i)
+ if (x == 0) {
+ partial = false
+ } else {
+ if (partial) {
+ cur = cur.project(-x)
+ } else {
+ cur = cur.project(x)
+ partial = true
+ }
+ }
+ i += 1
+ }
+ cur
+ }
+
+ /**
+ * Projects this postfix with respect to the input prefix.
+ */
+ def project(prefix: Prefix): Postfix = project(prefix.items)
+
+ /**
+ * Returns the same sequence with compressed storage if possible.
+ */
+ def compressed: Postfix = {
+ if (start > 0) {
+ new Postfix(items.slice(start, items.length), 0, partialStarts.map(_ - start))
+ } else {
+ this
+ }
}
}
+
+ /**
 +   * Represents a frequent sequence.
+ * @param sequence a sequence of itemsets stored as an Array of Arrays
+ * @param freq frequency
+ * @tparam Item item type
+ */
+ class FreqSequence[Item](val sequence: Array[Array[Item]], val freq: Long) extends Serializable {
+ /**
+ * Returns sequence as a Java List of lists for Java users.
+ */
+ def javaSequence: ju.List[ju.List[Item]] = sequence.map(_.toList.asJava).toList.asJava
+ }
}
+
+/**
+ * Model fitted by [[PrefixSpan]]
+ * @param freqSequences frequent sequences
+ * @tparam Item item type
+ */
+class PrefixSpanModel[Item](val freqSequences: RDD[PrefixSpan.FreqSequence[Item]])
+ extends Serializable
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 88914fa875990..1139ce36d50b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -179,12 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
val tpe = row.getByte(0)
val numRows = row.getInt(1)
val numCols = row.getInt(2)
- val values = row.getArray(5).toArray.map(_.asInstanceOf[Double])
+ val values = row.getArray(5).toDoubleArray()
val isTransposed = row.getBoolean(6)
tpe match {
case 0 =>
- val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int])
- val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int])
+ val colPtrs = row.getArray(3).toIntArray()
+ val rowIndices = row.getArray(4).toIntArray()
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
case 1 =>
new DenseMatrix(numRows, numCols, values, isTransposed)
@@ -257,8 +257,7 @@ class DenseMatrix(
this(numRows, numCols, values, false)
override def equals(o: Any): Boolean = o match {
- case m: DenseMatrix =>
- m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray)
+ case m: Matrix => toBreeze == m.toBreeze
case _ => false
}
@@ -519,6 +518,11 @@ class SparseMatrix(
rowIndices: Array[Int],
values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false)
+ override def equals(o: Any): Boolean = o match {
+ case m: Matrix => toBreeze == m.toBreeze
+ case _ => false
+ }
+
private[mllib] def toBreeze: BM[Double] = {
if (!isTransposed) {
new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
index b416d50a5631e..cff5dbeee3e57 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala
@@ -31,5 +31,5 @@ case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VTyp
* Represents QR factors.
*/
@Experimental
-case class QRDecomposition[UType, VType](Q: UType, R: VType)
+case class QRDecomposition[QType, RType](Q: QType, R: RType)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 89a1818db0d1d..df15d985c814c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -26,7 +26,7 @@ import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
import org.apache.spark.SparkException
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
@@ -159,15 +159,13 @@ sealed trait Vector extends Serializable {
}
/**
- * :: DeveloperApi ::
+ * :: AlphaComponent ::
*
* User-defined type for [[Vector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.DataFrame]].
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
-@DeveloperApi
-private[spark] class VectorUDT extends UserDefinedType[Vector] {
+@AlphaComponent
+class VectorUDT extends UserDefinedType[Vector] {
override def sqlType: StructType = {
// type: 0 = sparse, 1 = dense
@@ -209,11 +207,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
tpe match {
case 0 =>
val size = row.getInt(1)
- val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int])
- val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
+ val indices = row.getArray(2).toIntArray()
+ val values = row.getArray(3).toDoubleArray()
new SparseVector(size, indices, values)
case 1 =>
- val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
+ val values = row.getArray(3).toDoubleArray()
new DenseVector(values)
}
}
@@ -766,6 +764,30 @@ class SparseVector(
maxIdx
}
}
+
+ /**
+ * Create a slice of this vector based on the given indices.
+ * @param selectedIndices Unsorted list of indices into the vector.
+ * This does NOT do bound checking.
+ * @return New SparseVector with values in the order specified by the given indices.
+ *
+ * NOTE: The API needs to be discussed before making this public.
+ * Also, if we have a version assuming indices are sorted, we should optimize it.
+ */
+ private[spark] def slice(selectedIndices: Array[Int]): SparseVector = {
+ var currentIdx = 0
+ val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx =>
+ val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx)
+ val i_v = if (iIdx >= 0) {
+ Iterator((currentIdx, this.values(iIdx)))
+ } else {
+ Iterator()
+ }
+ currentIdx += 1
+ i_v
+ }.unzip
+ new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
+ }
}
object SparseVector {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index ab7611fd077ef..8f0d1e4aa010a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector}
* @param gradient Gradient function to be used.
* @param updater Updater to be used to update weights after every iteration.
*/
-class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater)
+class GradientDescent private[spark] (private var gradient: Gradient, private var updater: Updater)
extends Optimizer with Logging {
private var stepSize: Double = 1.0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 62da9f2ef22a3..64e4be0ebb97e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -153,6 +153,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
/**
+ * Sample mean of each dimension.
+ *
* @since 1.1.0
*/
override def mean: Vector = {
@@ -168,6 +170,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
/**
+ * Sample variance of each dimension.
+ *
* @since 1.1.0
*/
override def variance: Vector = {
@@ -193,11 +197,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
/**
+ * Sample size.
+ *
* @since 1.1.0
*/
override def count: Long = totalCnt
/**
+ * Number of nonzero elements in each dimension.
+ *
* @since 1.1.0
*/
override def numNonzeros: Vector = {
@@ -207,6 +215,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
/**
+ * Maximum value of each dimension.
+ *
* @since 1.1.0
*/
override def max: Vector = {
@@ -221,6 +231,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
/**
+ * Minimum value of each dimension.
+ *
* @since 1.1.0
*/
override def min: Vector = {
@@ -235,6 +247,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
/**
 +   * L2 (Euclidean) norm of each dimension.
+ *
* @since 1.2.0
*/
override def normL2: Vector = {
@@ -252,6 +266,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
}
/**
+ * L1 norm of each dimension.
+ *
* @since 1.2.0
*/
override def normL1: Vector = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
index f84502919e381..24fe48cb8f71f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat
import scala.annotation.varargs
import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD}
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -178,6 +178,9 @@ object Statistics {
ChiSqTest.chiSquaredFeatures(data)
}
+ /** Java-friendly version of [[chiSqTest()]] */
+ def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = chiSqTest(data.rdd)
+
/**
* Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
* continuous distribution. By comparing the largest difference between the empirical cumulative
@@ -212,4 +215,15 @@ object Statistics {
: KolmogorovSmirnovTestResult = {
KolmogorovSmirnovTest.testOneSample(data, distName, params: _*)
}
+
+ /** Java-friendly version of [[kolmogorovSmirnovTest()]] */
+ @varargs
+ def kolmogorovSmirnovTest(
+ data: JavaDoubleRDD,
+ distName: String,
+ params: java.lang.Double*): KolmogorovSmirnovTestResult = {
+ val javaParams = params.asInstanceOf[Seq[Double]]
+ KolmogorovSmirnovTest.testOneSample(data.rdd.asInstanceOf[RDD[Double]],
+ distName, javaParams: _*)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index a835f96d5d0e3..9ce6faa137c41 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.configuration.Algo._
@@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging {
false
}
+ // Prepare periodic checkpointers
+ val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, input.sparkContext)
+ val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+ treeStrategy.getCheckpointInterval, input.sparkContext)
+
timer.stop("init")
logDebug("##########")
logDebug("Building tree 0")
logDebug("##########")
- var data = input
// Initialize tree
timer.start("building tree 0")
- val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+ val firstTreeModel = new DecisionTree(treeStrategy).run(input)
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+ predErrorCheckpointer.update(predError)
logDebug("error of gbt = " + predError.values.mean())
// Note: A model of type regression is used since we require raw prediction
@@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging {
var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+ if (validate) validatePredErrorCheckpointer.update(validatePredError)
var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
var bestM = 1
- // pseudo-residual for second iteration
- data = predError.zip(input).map { case ((pred, _), point) =>
- LabeledPoint(-loss.gradient(pred, point.label), point.features)
- }
-
var m = 1
- while (m < numIterations) {
+ var doneLearning = false
+ while (m < numIterations && !doneLearning) {
+ // Update data with pseudo-residuals
+ val data = predError.zip(input).map { case ((pred, _), point) =>
+ LabeledPoint(-loss.gradient(pred, point.label), point.features)
+ }
+
timer.start(s"building tree $m")
logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################")
val model = new DecisionTree(treeStrategy).run(data)
timer.stop(s"building tree $m")
- // Create partial model
+ // Update partial model
baseLearners(m) = model
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
// Technically, the weight should be optimized for the particular loss.
// However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
- // Note: A model of type regression is used since we require raw prediction
- val partialModel = new GradientBoostedTreesModel(
- Regression, baseLearners.slice(0, m + 1),
- baseLearnerWeights.slice(0, m + 1))
predError = GradientBoostedTreesModel.updatePredictionError(
input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+ predErrorCheckpointer.update(predError)
logDebug("error of gbt = " + predError.values.mean())
if (validate) {
@@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging {
validatePredError = GradientBoostedTreesModel.updatePredictionError(
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+ validatePredErrorCheckpointer.update(validatePredError)
val currentValidateError = validatePredError.values.mean()
if (bestValidateError - currentValidateError < validationTol) {
- return new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo,
- baseLearners.slice(0, bestM),
- baseLearnerWeights.slice(0, bestM))
+ doneLearning = true
} else if (currentValidateError < bestValidateError) {
- bestValidateError = currentValidateError
- bestM = m + 1
+ bestValidateError = currentValidateError
+ bestM = m + 1
}
}
- // Update data with pseudo-residuals
- data = predError.zip(input).map { case ((pred, _), point) =>
- LabeledPoint(-loss.gradient(pred, point.label), point.features)
- }
m += 1
}
@@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
+ predErrorCheckpointer.deleteAllCheckpoints()
+ validatePredErrorCheckpointer.deleteAllCheckpoints()
if (persistedInput) input.unpersist()
if (validate) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 2d6b01524ff3d..50fe2ac53da9d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* learning rate should be between in the interval (0, 1]
* @param validationTol Useful when runWithValidation is used. If the error rate on the
* validation input between two iterations is less than the validationTol
- * then stop. Ignored when [[run]] is used.
+ * then stop. Ignored when
+ * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
*/
@Experimental
case class BoostingStrategy(
@@ -89,7 +90,7 @@ object BoostingStrategy {
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: Algo): BoostingStrategy = {
- val treeStrategy = Strategy.defaultStategy(algo)
+ val treeStrategy = Strategy.defaultStrategy(algo)
treeStrategy.maxDepth = 3
algo match {
case Algo.Classification =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index ada227c200a79..de2c784809443 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -178,14 +178,14 @@ object Strategy {
* @param algo "Classification" or "Regression"
*/
def defaultStrategy(algo: String): Strategy = {
- defaultStategy(Algo.fromString(algo))
+ defaultStrategy(Algo.fromString(algo))
}
/**
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo Algo.Classification or Algo.Regression
*/
- def defaultStategy(algo: Algo): Strategy = algo match {
+ def defaultStrategy(algo: Algo): Strategy = algo match {
case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
@@ -193,4 +193,8 @@ object Strategy {
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}
+
+ @deprecated("Use Strategy.defaultStrategy instead.", "1.5.0")
+ def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo)
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 380291ac22bd3..9fe264656ede7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -128,9 +128,13 @@ private[spark] object DecisionTreeMetadata extends Logging {
// based on the number of training examples.
if (strategy.categoricalFeaturesInfo.nonEmpty) {
val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+ val maxCategory =
+ strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
require(maxCategoriesPerFeature <= maxPossibleBins,
- s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
- s"in categorical features (= $maxCategoriesPerFeature)")
+ s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
+ s"number of values in each categorical feature, but categorical feature $maxCategory " +
+ s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " +
+ "features with a large number of values, or add more training examples.")
}
val unorderedFeatures = new mutable.HashSet[Int]()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 5ac10f3fd32dd..0768204c33914 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 19d318203c344..d0077db6832e3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 578749d85a4e6..86cee7e430b0a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) {
+private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {
/**
* Make a deep copy of this [[ImpurityCalculator]].
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 7104a7fa4dd4c..04d0cd24e6632 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -98,7 +98,7 @@ private[tree] class VarianceAggregator()
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
require(stats.size == 3,
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index dc9e0f9f51ffb..508bf9c1bdb47 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.tree.model
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
/**
* :: DeveloperApi ::
@@ -66,7 +67,6 @@ class InformationGainStats(
}
}
-
private[spark] object InformationGainStats {
/**
* An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
@@ -76,3 +76,62 @@ private[spark] object InformationGainStats {
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
new Predict(0.0, 0.0), new Predict(0.0, 0.0))
}
+
+/**
+ * :: DeveloperApi ::
+ * Impurity statistics for each split
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param impurityCalculator impurity statistics for current node
+ * @param leftImpurityCalculator impurity statistics for left child node
+ * @param rightImpurityCalculator impurity statistics for right child node
+ * @param valid whether the current split satisfies minimum info gain or
+ * minimum number of instances per node
+ */
+@DeveloperApi
+private[spark] class ImpurityStats(
+ val gain: Double,
+ val impurity: Double,
+ val impurityCalculator: ImpurityCalculator,
+ val leftImpurityCalculator: ImpurityCalculator,
+ val rightImpurityCalculator: ImpurityCalculator,
+ val valid: Boolean = true) extends Serializable {
+
+ override def toString: String = {
+ s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " +
+ s"right impurity = $rightImpurity"
+ }
+
+ def leftImpurity: Double = if (leftImpurityCalculator != null) {
+ leftImpurityCalculator.calculate()
+ } else {
+ -1.0
+ }
+
+ def rightImpurity: Double = if (rightImpurityCalculator != null) {
+ rightImpurityCalculator.calculate()
+ } else {
+ -1.0
+ }
+}
+
+private[spark] object ImpurityStats {
+
+ /**
+ * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to
 +   * denote that the current split doesn't satisfy minimum info gain or
+ * minimum number of instances per node.
+ */
+ def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
+ new ImpurityStats(Double.MinValue, impurityCalculator.calculate(),
+ impurityCalculator, null, null, false)
+ }
+
+ /**
+ * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object
 +   * in which only 'impurity' and 'impurityCalculator' are defined.
+ */
+ def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
+ new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 7c5cfa7bd84ce..11ed23176fc12 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -64,6 +64,7 @@ object MLUtils {
* feature dimensions.
* @param minPartitions min number of partitions
* @return labeled data stored as an RDD[LabeledPoint]
+ * @since 1.0.0
*/
def loadLibSVMFile(
sc: SparkContext,
@@ -114,6 +115,9 @@ object MLUtils {
// Convenient methods for `loadLibSVMFile`.
+ /**
+ * @since 1.0.0
+ */
@deprecated("use method without multiclass argument, which no longer has effect", "1.1.0")
def loadLibSVMFile(
sc: SparkContext,
@@ -126,6 +130,7 @@ object MLUtils {
/**
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of
* partitions.
+ * @since 1.0.0
*/
def loadLibSVMFile(
sc: SparkContext,
@@ -133,6 +138,9 @@ object MLUtils {
numFeatures: Int): RDD[LabeledPoint] =
loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions)
+ /**
+ * @since 1.0.0
+ */
@deprecated("use method without multiclass argument, which no longer has effect", "1.1.0")
def loadLibSVMFile(
sc: SparkContext,
@@ -141,6 +149,9 @@ object MLUtils {
numFeatures: Int): RDD[LabeledPoint] =
loadLibSVMFile(sc, path, numFeatures)
+ /**
+ * @since 1.0.0
+ */
@deprecated("use method without multiclass argument, which no longer has effect", "1.1.0")
def loadLibSVMFile(
sc: SparkContext,
@@ -151,6 +162,7 @@ object MLUtils {
/**
* Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of
* features determined automatically and the default number of partitions.
+ * @since 1.0.0
*/
def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] =
loadLibSVMFile(sc, path, -1)
@@ -181,12 +193,14 @@ object MLUtils {
* @param path file or directory path in any Hadoop-supported file system URI
* @param minPartitions min number of partitions
* @return vectors stored as an RDD[Vector]
+ * @since 1.1.0
*/
def loadVectors(sc: SparkContext, path: String, minPartitions: Int): RDD[Vector] =
sc.textFile(path, minPartitions).map(Vectors.parse)
/**
* Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions.
+ * @since 1.1.0
*/
def loadVectors(sc: SparkContext, path: String): RDD[Vector] =
sc.textFile(path, sc.defaultMinPartitions).map(Vectors.parse)
@@ -197,6 +211,7 @@ object MLUtils {
* @param path file or directory path in any Hadoop-supported file system URI
* @param minPartitions min number of partitions
* @return labeled points stored as an RDD[LabeledPoint]
+ * @since 1.1.0
*/
def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] =
sc.textFile(path, minPartitions).map(LabeledPoint.parse)
@@ -204,6 +219,7 @@ object MLUtils {
/**
* Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of
* partitions.
+ * @since 1.1.0
*/
def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] =
loadLabeledPoints(sc, dir, sc.defaultMinPartitions)
@@ -220,6 +236,7 @@ object MLUtils {
*
* @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and
* [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading.
+ * @since 1.0.0
*/
@deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1")
def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
@@ -241,6 +258,7 @@ object MLUtils {
*
* @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and
* [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading.
+ * @since 1.0.0
*/
@deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1")
def saveLabeledData(data: RDD[LabeledPoint], dir: String) {
@@ -253,6 +271,7 @@ object MLUtils {
* Return a k element array of pairs of RDDs with the first element of each pair
* containing the training data, a complement of the validation data and the second
* element, the validation data, containing a unique 1/kth of the data. Where k=numFolds.
+ * @since 1.0.0
*/
@Experimental
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
@@ -268,6 +287,7 @@ object MLUtils {
/**
* Returns a new vector with `1.0` (bias) appended to the input vector.
+ * @since 1.0.0
*/
def appendBias(vector: Vector): Vector = {
vector match {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index f75e024a713ee..618b95b9bd126 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -87,6 +87,8 @@ public void logisticRegressionWithSetters() {
LogisticRegression parent = (LogisticRegression) model.parent();
assert(parent.getMaxIter() == 10);
assert(parent.getRegParam() == 1.0);
+ assert(parent.getThresholds()[0] == 0.4);
+ assert(parent.getThresholds()[1] == 0.6);
assert(parent.getThreshold() == 0.6);
assert(model.getThreshold() == 0.6);
@@ -147,4 +149,13 @@ public void logisticRegressionPredictorClassifierMethods() {
}
}
}
+
+ @Test
+ public void logisticRegressionTrainingSummary() {
+ LogisticRegression lr = new LogisticRegression();
+ LogisticRegressionModel model = lr.fit(dataset);
+
+ LogisticRegressionTrainingSummary summary = model.summary();
+ assert(summary.totalIterations() == summary.objectiveHistory().length);
+ }
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 09a9fba0c19cf..a700c9cddb206 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -68,7 +68,7 @@ public void naiveBayesDefaultParams() {
assert(nb.getLabelCol() == "label");
assert(nb.getFeaturesCol() == "features");
assert(nb.getPredictionCol() == "prediction");
- assert(nb.getLambda() == 1.0);
+ assert(nb.getSmoothing() == 1.0);
assert(nb.getModelType() == "multinomial");
}
@@ -89,7 +89,7 @@ public void testNaiveBayes() {
});
DataFrame dataset = jsql.createDataFrame(jrdd, schema);
- NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial");
+ NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);
DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 32d0b3856b7e2..a66a1e12927be 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -29,6 +29,7 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
@@ -85,6 +86,7 @@ public void runDT() {
model.toDebugString();
model.trees();
model.treeWeights();
+ Vector importances = model.featureImportances();
/*
// TODO: Add test once save/load are implemented. SPARK-6725
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index e306ebadfe7cf..a00ce5e249c34 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -29,6 +29,7 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
@@ -85,6 +86,7 @@ public void runDT() {
model.toDebugString();
model.trees();
model.treeWeights();
+ Vector importances = model.featureImportances();
/*
// TODO: Add test once save/load are implemented. SPARK-6725
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index d272a42c8576f..427be9430d820 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -124,6 +124,10 @@ public Boolean call(Tuple2 tuple2) {
}
});
assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
+
+ // Check: javaTopTopicsPerDocuments
+ JavaRDD> topTopics =
+ model.javaTopTopicsPerDocument(3);
}
@Test
@@ -160,11 +164,31 @@ public void OnlineOptimizerCompatibility() {
assertEquals(roundedLocalTopicSummary.length, k);
}
+ @Test
+ public void localLdaMethods() {
+ JavaRDD> docs = sc.parallelize(toyData, 2);
+ JavaPairRDD pairedDocs = JavaPairRDD.fromJavaRDD(docs);
+
+ // check: topicDistributions
+ assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count());
+
+ // check: logPerplexity
+ double logPerplexity = toyModel.logPerplexity(pairedDocs);
+
+ // check: logLikelihood.
+ ArrayList> docsSingleWord = new ArrayList>();
+ docsSingleWord.add(new Tuple2(Long.valueOf(0), Vectors.dense(1.0, 0.0, 0.0)));
+ JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord));
+ double logLikelihood = toyModel.logLikelihood(single);
+ }
+
private static int tinyK = LDASuite$.MODULE$.tinyK();
private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
private static Tuple2[] tinyTopicDescription =
LDASuite$.MODULE$.tinyTopicDescription();
private JavaPairRDD corpus;
+ private LocalLDAModel toyModel = LDASuite$.MODULE$.toyModel();
+ private ArrayList> toyData = LDASuite$.MODULE$.javaToyData();
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
index effc8a1a6dabc..fa4d334801ce4 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
@@ -18,12 +18,12 @@
package org.apache.spark.mllib.evaluation;
import java.io.Serializable;
-import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
import scala.Tuple2;
import scala.Tuple2$;
-import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -34,18 +34,18 @@
public class JavaRankingMetricsSuite implements Serializable {
private transient JavaSparkContext sc;
- private transient JavaRDD, ArrayList>> predictionAndLabels;
+ private transient JavaRDD, List>> predictionAndLabels;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaRankingMetricsSuite");
- predictionAndLabels = sc.parallelize(Lists.newArrayList(
+ predictionAndLabels = sc.parallelize(Arrays.asList(
Tuple2$.MODULE$.apply(
- Lists.newArrayList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Lists.newArrayList(1, 2, 3, 4, 5)),
+ Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
Tuple2$.MODULE$.apply(
- Lists.newArrayList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Lists.newArrayList(1, 2, 3)),
+ Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
Tuple2$.MODULE$.apply(
- Lists.newArrayList(1, 2, 3, 4, 5), Lists.newArrayList())), 2);
+ Arrays.asList(1, 2, 3, 4, 5), Arrays.asList())), 2);
}
@After
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
new file mode 100644
index 0000000000000..34daf5fbde80f
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
+
+public class JavaPrefixSpanSuite {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaPrefixSpan");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runPrefixSpan() {
+ JavaRDD>> sequences = sc.parallelize(Arrays.asList(
+ Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
+ Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
+ Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
+ Arrays.asList(Arrays.asList(6))
+ ), 2);
+ PrefixSpan prefixSpan = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5);
+ PrefixSpanModel model = prefixSpan.run(sequences);
+ JavaRDD> freqSeqs = model.freqSequences().toJavaRDD();
+ List> localFreqSeqs = freqSeqs.collect();
+ Assert.assertEquals(5, localFreqSeqs.size());
+ // Check that each frequent sequence could be materialized.
+ for (PrefixSpan.FreqSequence freqSeq: localFreqSeqs) {
+ List> seq = freqSeq.javaSequence();
+ long freq = freqSeq.freq();
+ }
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 62f7f26b7c98f..eb4e3698624bc 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -27,7 +27,12 @@
import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.stat.test.ChiSqTestResult;
+import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
public class JavaStatisticsSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -53,4 +58,21 @@ public void testCorr() {
// Check default method
assertEquals(corr1, corr2);
}
+
+ @Test
+ public void kolmogorovSmirnovTest() {
+ JavaDoubleRDD data = sc.parallelizeDoubles(Lists.newArrayList(0.2, 1.0, -1.0, 2.0));
+ KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
+ KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
+ data, "norm", 0.0, 1.0);
+ }
+
+ @Test
+ public void chiSqTest() {
+ JavaRDD data = sc.parallelize(Lists.newArrayList(
+ new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
+ new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
+ new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
+ ChiSqTestResult[] testResults = Statistics.chiSqTest(data);
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 63d2fa31c7499..1f2c9b75b617b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.sql.DataFrame
class PipelineSuite extends SparkFunSuite {
@@ -65,6 +66,8 @@ class PipelineSuite extends SparkFunSuite {
.setStages(Array(estimator0, transformer1, estimator2, transformer3))
val pipelineModel = pipeline.fit(dataset0)
+ MLTestingUtils.checkCopy(pipelineModel)
+
assert(pipelineModel.stages.length === 4)
assert(pipelineModel.stages(0).eq(model0))
assert(pipelineModel.stages(1).eq(transformer1))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
new file mode 100644
index 0000000000000..1292e57d7c01a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+
+class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ // TODO: test for weights comparison with Weka MLP
+ test("ANN with Sigmoid learns XOR function with LBFGS optimizer") {
+ val inputs = Array(
+ Array(0.0, 0.0),
+ Array(0.0, 1.0),
+ Array(1.0, 0.0),
+ Array(1.0, 1.0)
+ )
+ val outputs = Array(0.0, 1.0, 1.0, 0.0)
+ val data = inputs.zip(outputs).map { case (features, label) =>
+ (Vectors.dense(features), Vectors.dense(label))
+ }
+ val rddData = sc.parallelize(data, 1)
+ val hiddenLayersTopology = Array(5)
+ val dataSample = rddData.first()
+ val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
+ val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
+ val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val trainer = new FeedForwardTrainer(topology, 2, 1)
+ trainer.setWeights(initialWeights)
+ trainer.LBFGSOptimizer.setNumIterations(20)
+ val model = trainer.train(rddData)
+ val predictionAndLabels = rddData.map { case (input, label) =>
+ (model.predict(input)(0), label(0))
+ }.collect()
+ predictionAndLabels.foreach { case (p, l) =>
+ assert(math.round(p) === l)
+ }
+ }
+
+ test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") {
+ val inputs = Array(
+ Array(0.0, 0.0),
+ Array(0.0, 1.0),
+ Array(1.0, 0.0),
+ Array(1.0, 1.0)
+ )
+ val outputs = Array(
+ Array(1.0, 0.0),
+ Array(0.0, 1.0),
+ Array(0.0, 1.0),
+ Array(1.0, 0.0)
+ )
+ val data = inputs.zip(outputs).map { case (features, label) =>
+ (Vectors.dense(features), Vectors.dense(label))
+ }
+ val rddData = sc.parallelize(data, 1)
+ val hiddenLayersTopology = Array(5)
+ val dataSample = rddData.first()
+ val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
+ val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
+ val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val trainer = new FeedForwardTrainer(topology, 2, 2)
+ trainer.SGDOptimizer.setNumIterations(2000)
+ trainer.setWeights(initialWeights)
+ val model = trainer.train(rddData)
+ val predictionAndLabels = rddData.map { case (input, label) =>
+ (model.predict(input), label)
+ }.collect()
+ predictionAndLabels.foreach { case (p, l) =>
+ assert(p ~== l absTol 0.5)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 73b4805c4c597..4b7c5d3f23d2c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,12 +21,14 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -57,7 +59,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new DecisionTreeClassifier)
- val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
+ val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
ParamsSuite.checkParams(model)
}
@@ -231,6 +233,34 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
}
+ test("predictRaw and predictProbability") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3)
+ val numClasses = 3
+
+ val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val newTree = dt.fit(newData)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(newTree)
+
+ val predictions = newTree.transform(newData)
+ .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
+ .collect()
+
+ predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+ assert(pred === rawPred.argmax,
+ s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+ val sum = rawPred.toArray.sum
+ assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+ "probability prediction mismatch")
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 82c345491bb3c..e3909bccaa5ca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -22,12 +22,14 @@ import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.util.Utils
/**
@@ -57,7 +59,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new GBTClassifier)
val model = new GBTClassificationModel("gbtc",
- Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
+ Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
Array(1.0))
ParamsSuite.checkParams(model)
}
@@ -76,6 +78,28 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ sc.setCheckpointDir(path)
+
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
+ val gbt = new GBTClassifier()
+ .setMaxDepth(2)
+ .setLossType("logistic")
+ .setMaxIter(5)
+ .setStepSize(0.1)
+ .setCheckpointInterval(2)
+ val model = gbt.fit(df)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
+ sc.checkpointDir = None
+ Utils.deleteRecursively(tempDir)
+ }
+
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
/*
test("runWithValidation stops early and performs better on a validation dataset") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index b7dd44753896a..cce39f382f738 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -91,11 +92,53 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.hasParent)
}
+ test("setThreshold, getThreshold") {
+ val lr = new LogisticRegression
+ // default
+ assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5")
+ withClue("LogisticRegression should not have thresholds set by default.") {
+ intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future
+ lr.getThresholds
+ }
+ }
+ // Set via threshold.
+ // Intuition: Large threshold or large thresholds(1) makes class 0 more likely.
+ lr.setThreshold(1.0)
+ assert(lr.getThresholds === Array(0.0, 1.0))
+ lr.setThreshold(0.0)
+ assert(lr.getThresholds === Array(1.0, 0.0))
+ lr.setThreshold(0.5)
+ assert(lr.getThresholds === Array(0.5, 0.5))
+ // Set via thresholds
+ val lr2 = new LogisticRegression
+ lr2.setThresholds(Array(0.3, 0.7))
+ val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7)
+ assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7)
+ // thresholds and threshold must be consistent
+ lr2.setThresholds(Array(0.1, 0.2, 0.3))
+ withClue("getThreshold should throw error if thresholds has length != 2.") {
+ intercept[IllegalArgumentException] {
+ lr2.getThreshold
+ }
+ }
+ // thresholds and threshold must be consistent: values
+ withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
+ intercept[IllegalArgumentException] {
+ val lr2model = lr2.fit(dataset,
+ lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
+ lr2model.getThreshold
+ }
+ }
+ }
+
test("logistic regression doesn't fit intercept when fitIntercept is off") {
val lr = new LogisticRegression
lr.setFitIntercept(false)
val model = lr.fit(dataset)
assert(model.intercept === 0.0)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
}
test("logistic regression with setters") {
@@ -123,14 +166,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
// Call transform with params, and check that the params worked.
val predNotAllZero =
- model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
+ model.transform(dataset, model.threshold -> 0.0,
+ model.probabilityCol -> "myProb")
.select("prediction", "myProb")
.collect()
.map { case Row(pred: Double, prob: Vector) => pred }
assert(predNotAllZero.exists(_ !== 0.0))
// Call fit() with new params, and check as many params as we can.
- val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
+ lr.setThresholds(Array(0.6, 0.4))
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1,
lr.probabilityCol -> "theProb")
val parent2 = model2.parent.asInstanceOf[LogisticRegression]
assert(parent2.getMaxIter === 5)
@@ -699,6 +744,41 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0)
assert(model1.intercept ~== interceptR relTol 1E-5)
- assert(model1.weights ~= weightsR absTol 1E-6)
+ assert(model1.weights ~== weightsR absTol 1E-6)
+ }
+
+ test("evaluate on test set") {
+ // Evaluate on test set should be same as that of the transformed training data.
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ .setThreshold(0.6)
+ val model = lr.fit(dataset)
+ val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]
+
+ val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary]
+ assert(summary.areaUnderROC === sameSummary.areaUnderROC)
+ assert(summary.roc.collect() === sameSummary.roc.collect())
+ assert(summary.pr.collect === sameSummary.pr.collect())
+ assert(
+ summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect())
+ assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect())
+ assert(
+ summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
+ }
+
+ test("statistics on training data") {
+ // Test that loss is monotonically decreasing.
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ .setThreshold(0.6)
+ val model = lr.fit(dataset)
+ assert(
+ model.summary
+ .objectiveHistory
+ .sliding(2)
+ .forall(x => x(0) >= x(1)))
+
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
new file mode 100644
index 0000000000000..ddc948f65df45
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.Row
+
+class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("XOR function learning as binary classification problem with two outputs.") {
+ val dataFrame = sqlContext.createDataFrame(Seq(
+ (Vectors.dense(0.0, 0.0), 0.0),
+ (Vectors.dense(0.0, 1.0), 1.0),
+ (Vectors.dense(1.0, 0.0), 1.0),
+ (Vectors.dense(1.0, 1.0), 0.0))
+ ).toDF("features", "label")
+ val layers = Array[Int](2, 5, 2)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(11L)
+ .setMaxIter(100)
+ val model = trainer.fit(dataFrame)
+ val result = model.transform(dataFrame)
+ val predictionAndLabels = result.select("prediction", "label").collect()
+ predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
+ assert(p == l)
+ }
+ }
+
+ // TODO: implement a more rigorous test
+ test("3 class classification with 2 hidden layers") {
+ val nPoints = 1000
+
+ // The following weights are taken from OneVsRestSuite.scala
+ // they represent 3-class iris dataset
+ val weights = Array(
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
+
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+ val rdd = sc.parallelize(generateMultinomialLogisticInput(
+ weights, xMean, xVariance, true, nPoints, 42), 2)
+ val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features")
+ val numClasses = 3
+ val numIterations = 100
+ val layers = Array[Int](4, 5, 4, numClasses)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(11L)
+ .setMaxIter(numIterations)
+ val model = trainer.fit(dataFrame)
+ val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label")
+ .map { case Row(p: Double, l: Double) => (p, l) }
+ // train multinomial logistic regression
+ val lr = new LogisticRegressionWithLBFGS()
+ .setIntercept(true)
+ .setNumClasses(numClasses)
+ lr.optimizer.setRegParam(0.0)
+ .setNumIterations(numIterations)
+ val lrModel = lr.run(rdd)
+ val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label))
+ // MLP's predictions should not differ a lot from LR's.
+ val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels)
+ val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
+ assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 76381a2741296..98bc9511163e7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -17,8 +17,11 @@
package org.apache.spark.ml.classification
+import breeze.linalg.{Vector => BV}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -46,6 +49,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
}
+ def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+ val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
+ val classProbs = logClassProbs.toArray.map(math.exp)
+ val classProbsSum = classProbs.sum
+ Vectors.dense(classProbs.map(_ / classProbsSum))
+ }
+
+ def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+ val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v)))
+ val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v))
+ val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
+ val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze
+ val classProbs = logClassProbs.toArray.map(math.exp)
+ val classProbsSum = classProbs.sum
+ Vectors.dense(classProbs.map(_ / classProbsSum))
+ }
+
+ def validateProbabilities(
+ featureAndProbabilities: DataFrame,
+ model: NaiveBayesModel,
+ modelType: String): Unit = {
+ featureAndProbabilities.collect().foreach {
+ case Row(features: Vector, probability: Vector) => {
+ assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
+ val expected = modelType match {
+ case Multinomial =>
+ expectedMultinomialProbabilities(model, features)
+ case Bernoulli =>
+ expectedBernoulliProbabilities(model, features)
+ case _ =>
+ throw new UnknownError(s"Invalid modelType: $modelType.")
+ }
+ assert(probability ~== expected relTol 1.0e-10)
+ }
+ }
+ }
+
test("params") {
ParamsSuite.checkParams(new NaiveBayes)
val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
@@ -58,7 +98,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(nb.getLabelCol === "label")
assert(nb.getFeaturesCol === "features")
assert(nb.getPredictionCol === "prediction")
- assert(nb.getLambda === 1.0)
+ assert(nb.getSmoothing === 1.0)
assert(nb.getModelType === "multinomial")
}
@@ -75,7 +115,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 42, "multinomial"))
- val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial")
+ val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
val model = nb.fit(testDataset)
validateModelFit(pi, theta, model)
@@ -83,9 +123,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 17, "multinomial"))
- val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
validatePrediction(predictionAndLabels)
+
+ val featureAndProbabilities = model.transform(validationDataset)
+ .select("features", "probability")
+ validateProbabilities(featureAndProbabilities, model, "multinomial")
}
test("Naive Bayes Bernoulli") {
@@ -101,7 +145,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 45, "bernoulli"))
- val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli")
+ val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
val model = nb.fit(testDataset)
validateModelFit(pi, theta, model)
@@ -109,8 +153,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 20, "bernoulli"))
- val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
validatePrediction(predictionAndLabels)
+
+ val featureAndProbabilities = model.transform(validationDataset)
+ .select("features", "probability")
+ validateProbabilities(featureAndProbabilities, model, "bernoulli")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 3775292f6dca7..977f0e0b70c1a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -70,6 +70,10 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(ova.getLabelCol === "label")
assert(ova.getPredictionCol === "prediction")
val ovaModel = ova.fit(dataset)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(ovaModel)
+
assert(ovaModel.models.size === numClasses)
val transformedDataset = ovaModel.transform(dataset)
@@ -151,7 +155,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
"copy should handle extra classifier params")
- val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
+ val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(0.9, 0.1)))
ovrModel.models.foreach { case m: LogisticRegressionModel =>
require(m.getThreshold === 0.1, "copy should handle extra model params")
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
new file mode 100644
index 0000000000000..8f50cb924e64d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+
+final class TestProbabilisticClassificationModel(
+ override val uid: String,
+ override val numClasses: Int)
+ extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {
+
+ override def copy(extra: org.apache.spark.ml.param.ParamMap): this.type = defaultCopy(extra)
+
+ override protected def predictRaw(input: Vector): Vector = {
+ input
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction
+ }
+
+ def friendlyPredict(input: Vector): Double = {
+ predict(input)
+ }
+}
+
+
+class ProbabilisticClassifierSuite extends SparkFunSuite {
+
+ test("test thresholding") {
+ val thresholds = Array(0.5, 0.2)
+ val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds)
+ assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
+ assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
+ }
+
+ test("test thresholding not required") {
+ val testModel = new TestProbabilisticClassificationModel("myuid", 2)
+ assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ab711c8e4b215..b4403ec30049a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,11 +21,13 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
@@ -66,7 +68,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
- Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2)
ParamsSuite.checkParams(model)
}
@@ -121,6 +123,65 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
compareAPIs(rdd, rf2, categoricalFeatures, numClasses)
}
+ test("predictRaw and predictProbability") {
+ val rdd = orderedLabeledPoints5_20
+ val rf = new RandomForestClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(3)
+ .setNumTrees(3)
+ .setSeed(123)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val numClasses = 2
+
+ val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val model = rf.fit(df)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
+ val predictions = model.transform(df)
+ .select(rf.getPredictionCol, rf.getRawPredictionCol, rf.getProbabilityCol)
+ .collect()
+
+ predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+ assert(pred === rawPred.argmax,
+ s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+ val sum = rawPred.toArray.sum
+ assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+ "probability prediction mismatch")
+ assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+ test("Feature importance with toy data") {
+ val numClasses = 2
+ val rf = new RandomForestClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(3)
+ .setNumTrees(3)
+ .setFeatureSubsetStrategy("all")
+ .setSubsamplingRate(1.0)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = sc.parallelize(Seq(
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
+ ))
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+ val importances = rf.fit(df).featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
@@ -173,13 +234,5 @@ private object RandomForestClassifierSuite {
assert(newModel.hasParent)
assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
assert(newModel.numClasses == numClasses)
- val results = newModel.transform(newData)
- results.select("rawPrediction", "prediction").collect().foreach {
- case Row(raw: Vector, prediction: Double) => {
- assert(raw.size == numClasses)
- val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
- assert(predFromRaw == prediction)
- }
- }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 1f15ac02f4008..688b0e31f91dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -52,10 +52,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(kmeans.getFeaturesCol === "features")
assert(kmeans.getPredictionCol === "prediction")
assert(kmeans.getMaxIter === 20)
- assert(kmeans.getRuns === 1)
assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
assert(kmeans.getInitSteps === 5)
- assert(kmeans.getEpsilon === 1e-4)
+ assert(kmeans.getTol === 1e-4)
}
test("set parameters") {
@@ -64,21 +63,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
.setFeaturesCol("test_feature")
.setPredictionCol("test_prediction")
.setMaxIter(33)
- .setRuns(7)
.setInitMode(MLlibKMeans.RANDOM)
.setInitSteps(3)
.setSeed(123)
- .setEpsilon(1e-3)
+ .setTol(1e-3)
assert(kmeans.getK === 9)
assert(kmeans.getFeaturesCol === "test_feature")
assert(kmeans.getPredictionCol === "test_prediction")
assert(kmeans.getMaxIter === 33)
- assert(kmeans.getRuns === 7)
assert(kmeans.getInitMode === MLlibKMeans.RANDOM)
assert(kmeans.getInitSteps === 3)
assert(kmeans.getSeed === 123)
- assert(kmeans.getEpsilon === 1e-3)
+ assert(kmeans.getTol === 1e-3)
}
test("parameters validation") {
@@ -91,9 +88,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
intercept[IllegalArgumentException] {
new KMeans().setInitSteps(0)
}
- intercept[IllegalArgumentException] {
- new KMeans().setRuns(0)
- }
}
test("fit & transform") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
new file mode 100644
index 0000000000000..6d8412b0b3701
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+
+class MulticlassClassificationEvaluatorSuite extends SparkFunSuite {
+
+ test("params") {
+ ParamsSuite.checkParams(new MulticlassClassificationEvaluator)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index ec85e0d151e07..0eba34fda6228 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index c452054bec92f..c04dda41eea34 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}
@@ -51,6 +52,9 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
.foreach { case Row(vector1: Vector, vector2: Vector) =>
assert(vector1.equals(vector2), "Transformed vector is different with expected.")
}
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
}
test("MinMaxScaler arguments max must be larger than min") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 65846a846b7b4..321eeb843941c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
val output = encoder.transform(df)
val group = AttributeGroup.fromStructField(output.schema("encoded"))
assert(group.size === 2)
- assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0))
- assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1))
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
}
test("input column without ML attribute") {
@@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
val output = encoder.transform(df)
val group = AttributeGroup.fromStructField(output.schema("encoded"))
assert(group.size === 2)
- assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0))
- assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1))
+ assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
+ assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index d0ae36b28c7a9..30c500f87a769 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices}
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -56,6 +57,9 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
.setK(3)
.fit(df)
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(pca)
+
pca.transform(df).select("pca_features", "expected").collect().foreach {
case Row(x: Vector, y: Vector) =>
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 8148c553e9051..6aed3243afce8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected.collect())
}
+
+ test("attribute generation") {
+ val formula = new RFormula().setFormula("id ~ a + b")
+ val original = sqlContext.createDataFrame(
+ Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+ ).toDF("id", "a", "b")
+ val model = formula.fit(original)
+ val result = model.transform(original)
+ val attrs = AttributeGroup.fromStructField(result.schema("features"))
+ val expectedAttrs = new AttributeGroup(
+ "features",
+ Array(
+ new BinaryAttribute(Some("a__bar"), Some(1)),
+ new BinaryAttribute(Some("a__foo"), Some(2)),
+ new NumericAttribute(Some("b"), Some(3))))
+ assert(attrs === expectedAttrs)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
new file mode 100644
index 0000000000000..d19052881ae45
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("params") {
+ ParamsSuite.checkParams(new SQLTransformer())
+ }
+
+ test("transform numeric data") {
+ val original = sqlContext.createDataFrame(
+ Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
+ val sqlTrans = new SQLTransformer().setStatement(
+ "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
+ val result = sqlTrans.transform(original)
+ val resultSchema = sqlTrans.transformSchema(original.schema)
+ val expected = sqlContext.createDataFrame(
+ Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)))
+ .toDF("id", "v1", "v2", "v3", "v4")
+ assert(result.schema.toString == resultSchema.toString)
+ assert(resultSchema == expected.schema)
+ assert(result.collect().toSeq == expected.collect().toSeq)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
new file mode 100644
index 0000000000000..f01306f89cb5f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+object StopWordsRemoverSuite extends SparkFunSuite {
+ def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
+ t.transform(dataset)
+ .select("filtered", "expected")
+ .collect()
+ .foreach { case Row(tokens, wantedTokens) =>
+ assert(tokens === wantedTokens)
+ }
+ }
+}
+
+class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
+ import StopWordsRemoverSuite._
+
+ test("StopWordsRemover default") {
+ val remover = new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol("filtered")
+ val dataSet = sqlContext.createDataFrame(Seq(
+ (Seq("test", "test"), Seq("test", "test")),
+ (Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
+ (Seq("a", "the", "an"), Seq()),
+ (Seq("A", "The", "AN"), Seq()),
+ (Seq(null), Seq(null)),
+ (Seq(), Seq())
+ )).toDF("raw", "expected")
+
+ testStopWordsRemover(remover, dataSet)
+ }
+
+ test("StopWordsRemover case sensitive") {
+ val remover = new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol("filtered")
+ .setCaseSensitive(true)
+ val dataSet = sqlContext.createDataFrame(Seq(
+ (Seq("A"), Seq("A")),
+ (Seq("The", "the"), Seq("The"))
+ )).toDF("raw", "expected")
+
+ testStopWordsRemover(remover, dataSet)
+ }
+
+ test("StopWordsRemover with additional words") {
+ val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala")
+ val remover = new StopWordsRemover()
+ .setInputCol("raw")
+ .setOutputCol("filtered")
+ .setStopWords(stopWords)
+ val dataSet = sqlContext.createDataFrame(Seq(
+ (Seq("python", "scala", "a"), Seq()),
+ (Seq("Python", "Scala", "swift"), Seq("swift"))
+ )).toDF("raw", "expected")
+
+ testStopWordsRemover(remover, dataSet)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 99f82bea42688..fa918ce64877c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.ml.feature
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions.col
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -37,6 +40,10 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(indexer)
+
val transformed = indexer.transform(df)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
@@ -49,6 +56,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(output === expected)
}
+  test("StringIndexerUnseen") {
+    // df2 contains label "c" that the indexer fitted on df has never seen.
+    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
+    val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    // Verify we throw by default with unseen values
+    intercept[SparkException] {
+      indexer.transform(df2).collect()
+    }
+    // With handleInvalid = "skip", unseen labels are dropped instead of failing.
+    val indexerSkipInvalid = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .setHandleInvalid("skip")
+      .fit(df)
+    // Verify that we skip the c record
+    val transformed = indexerSkipInvalid.transform(df2)
+    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    // "b" appears twice in df so it gets index 0; "a" gets index 1.
+    assert(attr.values.get === Array("b", "a"))
+    val output = transformed.select("id", "labelIndex").map { r =>
+      (r.getInt(0), r.getDouble(1))
+    }.collect().toSet
+    // a -> 1, b -> 0
+    val expected = Set((0, 1.0), (1, 0.0))
+    assert(output === expected)
+  }
+
+
test("StringIndexer with a numeric input column") {
val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
@@ -75,4 +113,36 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
val df = sqlContext.range(0L, 10L)
assert(indexerModel.transform(df).eq(df))
}
+
+  test("IndexToString params") {
+    val idxToStr = new IndexToString()
+    ParamsSuite.checkParams(idxToStr)
+  }
+
+  test("IndexToString.transform") {
+    val labels = Array("a", "b", "c")
+    val df0 = sqlContext.createDataFrame(Seq(
+      (0, "a"), (1, "b"), (2, "c"), (0, "a")
+    )).toDF("index", "expected")
+
+    // Case 1: labels supplied explicitly via setLabels.
+    val idxToStr0 = new IndexToString()
+      .setInputCol("index")
+      .setOutputCol("actual")
+      .setLabels(labels)
+    idxToStr0.transform(df0).select("actual", "expected").collect().foreach {
+      case Row(actual, expected) =>
+        assert(actual === expected)
+    }
+
+    // Case 2: no explicit labels; they are read from the input column's
+    // NominalAttribute metadata instead.
+    val attr = NominalAttribute.defaultAttr.withValues(labels)
+    val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), col("expected"))
+
+    val idxToStr1 = new IndexToString()
+      .setInputCol("indexWithAttr")
+      .setOutputCol("actual")
+    idxToStr1.transform(df1).select("actual", "expected").collect().foreach {
+      case Row(actual, expected) =>
+        assert(actual === expected)
+    }
+  }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 03120c828ca96..8cb0a2cf14d37 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
@@ -109,6 +110,10 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L
test("Throws error when given RDDs with different size vectors") {
val vectorIndexer = getIndexer
val model = vectorIndexer.fit(densePoints1) // vectors of length 3
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
model.transform(densePoints1) // should work
model.transform(sparsePoints1) // should work
intercept[SparkException] {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
new file mode 100644
index 0000000000000..a6c2fba8360dd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+// Tests for VectorSlicer: param defaults/validation, the static index/name
+// validity helpers, and slicing by indices, by names, and by a mix of both.
+class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("params") {
+    val slicer = new VectorSlicer
+    ParamsSuite.checkParams(slicer)
+    assert(slicer.getIndices.length === 0)
+    assert(slicer.getNames.length === 0)
+    withClue("VectorSlicer should not have any features selected by default") {
+      intercept[IllegalArgumentException] {
+        slicer.validateParams()
+      }
+    }
+  }
+
+  test("feature validity checks") {
+    import VectorSlicer._
+    // Indices must be non-negative and free of duplicates.
+    assert(validIndices(Array(0, 1, 8, 2)))
+    assert(validIndices(Array.empty[Int]))
+    assert(!validIndices(Array(-1)))
+    assert(!validIndices(Array(1, 2, 1)))
+
+    // Names must be non-empty strings and free of duplicates.
+    assert(validNames(Array("a", "b")))
+    assert(validNames(Array.empty[String]))
+    assert(!validNames(Array("", "b")))
+    assert(!validNames(Array("a", "b", "a")))
+  }
+
+  test("Test vector slicer") {
+    val sqlContext = new SQLContext(sc)
+
+    val data = Array(
+      Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
+      Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
+      Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0),
+      Vectors.dense(0.6, -1.1, -3.0, 4.5, 3.3),
+      Vectors.sparse(5, Seq())
+    )
+
+    // Expected after selecting indices 1, 4
+    val expected = Array(
+      Vectors.sparse(2, Seq((0, 2.3))),
+      Vectors.dense(2.3, 1.0),
+      Vectors.dense(0.0, 0.0),
+      Vectors.dense(-1.1, 3.3),
+      Vectors.sparse(2, Seq())
+    )
+
+    // Attach per-feature names (f0..f4) via attribute-group metadata so the
+    // slicer can resolve selection-by-name.
+    val defaultAttr = NumericAttribute.defaultAttr
+    val attrs = Array("f0", "f1", "f2", "f3", "f4").map(defaultAttr.withName)
+    val attrGroup = new AttributeGroup("features", attrs.asInstanceOf[Array[Attribute]])
+
+    val resultAttrs = Array("f1", "f4").map(defaultAttr.withName)
+    val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])
+
+    val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
+    val df = sqlContext.createDataFrame(rdd,
+      StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))
+
+    val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
+
+    // Checks both the sliced vector values and the propagated attribute metadata.
+    def validateResults(df: DataFrame): Unit = {
+      df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) =>
+        assert(vec1 === vec2)
+      }
+      val resultMetadata = AttributeGroup.fromStructField(df.schema("result"))
+      val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected"))
+      assert(resultMetadata.numAttributes === expectedMetadata.numAttributes)
+      resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) =>
+        assert(a === b)
+      }
+    }
+
+    // Select by indices only, by a mix of index and name, and by names only;
+    // all three must produce the same result.
+    vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty)
+    validateResults(vectorSlicer.transform(df))
+
+    vectorSlicer.setIndices(Array(1)).setNames(Array("f4"))
+    validateResults(vectorSlicer.transform(df))
+
+    vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
+    validateResults(vectorSlicer.transform(df))
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index aa6ce533fd885..a2e46f2029956 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -62,10 +63,75 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
.setSeed(42L)
.fit(docDF)
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
model.transform(docDF).select("result", "expected").collect().foreach {
case Row(vector1: Vector, vector2: Vector) =>
assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.")
}
}
+
+  test("getVectors") {
+
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+
+    val sentence = "a b " * 100 + "a c " * 10
+    val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+
+    // Expected per-word embeddings for this corpus with seed 42 and
+    // vector size 3 (values pinned from a reference run of the trainer).
+    val codes = Map(
+      "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),
+      "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
+      "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351)
+    )
+    // Sort by word so the ordering matches the sort("word") query below.
+    val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) }
+
+    val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+    val model = new Word2Vec()
+      .setVectorSize(3)
+      .setInputCol("text")
+      .setOutputCol("result")
+      .setSeed(42L)
+      .fit(docDF)
+
+    val realVectors = model.getVectors.sort("word").select("vector").map {
+      case Row(v: Vector) => v
+    }.collect()
+
+    realVectors.zip(expectedVectors).foreach {
+      case (real, expected) =>
+        assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.")
+    }
+  }
+
+  test("findSynonyms") {
+
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+
+    val sentence = "a b " * 100 + "a c " * 10
+    val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+    val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+    val model = new Word2Vec()
+      .setVectorSize(3)
+      .setInputCol("text")
+      .setOutputCol("result")
+      .setSeed(42L)
+      .fit(docDF)
+
+    // Similarity scores pinned from a reference run with seed 42.
+    val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644)
+    val (synonyms, similarity) = model.findSynonyms("a", 2).map {
+      case Row(w: String, sim: Double) => (w, sim)
+    }.collect().unzip
+
+    assert(synonyms.toArray === Array("b", "c"))
+    // NOTE(review): `map` is used purely for its assertion side effect here;
+    // `foreach` would express the intent better.
+    expectedSimilarity.zip(similarity).map {
+      case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
+    }
+
+  }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 778abcba22c10..460849c79f04f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -124,4 +124,22 @@ private[ml] object TreeTests extends SparkFunSuite {
"checkEqual failed since the two tree ensembles were not identical")
}
}
+
+  /**
+   * Helper method for constructing a tree for testing.
+   * Given left, right children, construct a parent node.
+   * @param left Left child node
+   * @param right Right child node
+   * @param split Split for parent node
+   * @return Parent node with children attached
+   */
+  def buildParentNode(left: Node, right: Node, split: Split): Node = {
+    val leftImp = left.impurityStats
+    val rightImp = right.impurityStats
+    // Parent impurity stats are the aggregation of both children's stats.
+    val parentImp = leftImp.copy.add(rightImp)
+    // Children are weighted by their share of the parent's instance count.
+    val leftWeight = leftImp.count / parentImp.count.toDouble
+    val rightWeight = rightImp.count / parentImp.count.toDouble
+    // Information gain = parent impurity minus weighted child impurities.
+    val gain = parentImp.calculate() -
+      (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
+    val pred = parentImp.predict
+    new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
+  }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 050d4170ea017..be95638d81686 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -200,6 +200,14 @@ class ParamsSuite extends SparkFunSuite {
val inArray = ParamValidators.inArray[Int](Array(1, 2))
assert(inArray(1) && inArray(2) && !inArray(0))
}
+
+  test("Params.copyValues") {
+    // copy with an empty ParamMap must not mark unset params as set,
+    // while copying with an explicit value must.
+    val t = new TestParams()
+    val t2 = t.copy(ParamMap.empty)
+    assert(!t2.isSet(t2.maxIter))
+    val t3 = t.copy(ParamMap(t.maxIter -> 20))
+    assert(t3.isSet(t3.maxIter))
+  }
}
object ParamsSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 2e5cfe7027eb6..eadc80e0e62b1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -28,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.recommendation.ALS._
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -374,6 +375,9 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
}
logInfo(s"Test RMSE is $rmse.")
assert(rmse < targetRMSE)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
}
test("exact rank-1 matrix") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 33aa9d0d62343..b092bcd6a7e86 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite}
@@ -61,6 +62,16 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
}
+  test("copied model must have the same parent") {
+    // Fit a small regressor, then check (via MLTestingUtils.checkCopy)
+    // that copying the fitted model preserves its parent estimator.
+    val categoricalFeatures = Map(0 -> 2, 1-> 2)
+    val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
+    val model = new DecisionTreeRegressor()
+      .setImpurity("variance")
+      .setMaxDepth(2)
+      .setMaxBins(8).fit(df)
+    MLTestingUtils.checkCopy(model)
+  }
+
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 9682edcd9ba84..a68197b59193d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,13 +19,15 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.util.Utils
/**
@@ -81,6 +83,9 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setMaxDepth(2)
.setMaxIter(2)
val model = gbt.fit(df)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
val preds = model.transform(df)
val predictions = preds.select("prediction").map(_.getDouble(0))
// Checks based on SPARK-8736 (to ensure it is not doing classification)
@@ -88,6 +93,24 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(predictions.min() < -1)
}
+  test("Checkpointing") {
+    // Smoke test: training with a checkpoint interval set must complete
+    // without error when a checkpoint directory is configured.
+    // NOTE(review): there are no assertions on `model` — this only verifies
+    // fit() does not throw; confirm that is the intended coverage.
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    sc.setCheckpointDir(path)
+
+    val df = sqlContext.createDataFrame(data)
+    val gbt = new GBTRegressor()
+      .setMaxDepth(2)
+      .setMaxIter(5)
+      .setStepSize(0.1)
+      .setCheckpointInterval(2)
+    val model = gbt.fit(df)
+
+    // Reset checkpoint state and clean up the temp directory.
+    sc.checkpointDir = None
+    Utils.deleteRecursively(tempDir)
+
+  }
+
+
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
/*
test("runWithValidation stops early and performs better on a validation dataset") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
new file mode 100644
index 0000000000000..c0ab00b68a2f3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+// Tests for the ML-pipeline IsotonicRegression wrapper: predictions in both
+// isotonic and antitonic mode, param defaults/setters/validation, missing
+// input columns, and vector feature columns with an explicit feature index.
+class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+  // Builds a (label, features, weight) DataFrame where the feature is the
+  // element's index (as a Double column) and every weight is 1.0.
+  private def generateIsotonicInput(labels: Seq[Double]): DataFrame = {
+    sqlContext.createDataFrame(
+      labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) }
+    ).toDF("label", "features", "weight")
+  }
+
+  // Builds a single-column "features" DataFrame for prediction-only input.
+  private def generatePredictionInput(features: Seq[Double]): DataFrame = {
+    sqlContext.createDataFrame(features.map(Tuple1.apply))
+      .toDF("features")
+  }
+
+  test("isotonic regression predictions") {
+    val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18))
+    val ir = new IsotonicRegression().setIsotonic(true)
+
+    val model = ir.fit(dataset)
+
+    val predictions = model
+      .transform(dataset)
+      .select("prediction").map { case Row(pred) =>
+        pred
+      }.collect()
+
+    // Pool-adjacent-violators output: non-decreasing fit of the labels.
+    assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
+
+    assert(model.boundaries === Vectors.dense(0, 1, 3, 4, 5, 6, 7, 8))
+    assert(model.predictions === Vectors.dense(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0))
+    assert(model.getIsotonic)
+  }
+
+  test("antitonic regression predictions") {
+    val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1))
+    val ir = new IsotonicRegression().setIsotonic(false)
+
+    val model = ir.fit(dataset)
+    // Includes out-of-range and between-boundary features to exercise
+    // extrapolation and interpolation.
+    val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0))
+
+    val predictions = model
+      .transform(features)
+      .select("prediction").map {
+        case Row(pred) => pred
+      }.collect()
+
+    assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1))
+  }
+
+  test("params validation") {
+    val dataset = generateIsotonicInput(Seq(1, 2, 3))
+    val ir = new IsotonicRegression
+    ParamsSuite.checkParams(ir)
+    val model = ir.fit(dataset)
+    ParamsSuite.checkParams(model)
+  }
+
+  test("default params") {
+    val dataset = generateIsotonicInput(Seq(1, 2, 3))
+    val ir = new IsotonicRegression()
+    assert(ir.getLabelCol === "label")
+    assert(ir.getFeaturesCol === "features")
+    assert(ir.getPredictionCol === "prediction")
+    assert(!ir.isDefined(ir.weightCol))
+    assert(ir.getIsotonic)
+    assert(ir.getFeatureIndex === 0)
+
+    val model = ir.fit(dataset)
+    model.transform(dataset)
+      .select("label", "features", "prediction", "weight")
+      .collect()
+
+    // Defaults must be copied from estimator to model unchanged.
+    assert(model.getLabelCol === "label")
+    assert(model.getFeaturesCol === "features")
+    assert(model.getPredictionCol === "prediction")
+    assert(!model.isDefined(model.weightCol))
+    assert(model.getIsotonic)
+    assert(model.getFeatureIndex === 0)
+    assert(model.hasParent)
+  }
+
+  test("set parameters") {
+    val isotonicRegression = new IsotonicRegression()
+      .setIsotonic(false)
+      .setWeightCol("w")
+      .setFeaturesCol("f")
+      .setLabelCol("l")
+      .setPredictionCol("p")
+
+    assert(!isotonicRegression.getIsotonic)
+    assert(isotonicRegression.getWeightCol === "w")
+    assert(isotonicRegression.getFeaturesCol === "f")
+    assert(isotonicRegression.getLabelCol === "l")
+    assert(isotonicRegression.getPredictionCol === "p")
+  }
+
+  test("missing column") {
+    // Each setter points at a column that does not exist in the dataset,
+    // so fit/transform must reject it during schema validation.
+    val dataset = generateIsotonicInput(Seq(1, 2, 3))
+
+    intercept[IllegalArgumentException] {
+      new IsotonicRegression().setWeightCol("w").fit(dataset)
+    }
+
+    intercept[IllegalArgumentException] {
+      new IsotonicRegression().setFeaturesCol("f").fit(dataset)
+    }
+
+    intercept[IllegalArgumentException] {
+      new IsotonicRegression().setLabelCol("l").fit(dataset)
+    }
+
+    intercept[IllegalArgumentException] {
+      new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset)
+    }
+  }
+
+  test("vector features column with feature index") {
+    // Features column holds vectors; featureIndex = 1 selects the second
+    // element of each vector as the regression input.
+    val dataset = sqlContext.createDataFrame(Seq(
+      (4.0, Vectors.dense(0.0, 1.0)),
+      (3.0, Vectors.dense(0.0, 2.0)),
+      (5.0, Vectors.sparse(2, Array(1), Array(3.0))))
+    ).toDF("label", "features")
+
+    val ir = new IsotonicRegression()
+      .setFeatureIndex(1)
+
+    val model = ir.fit(dataset)
+
+    val features = generatePredictionInput(Seq(2.0, 3.0, 4.0, 5.0))
+
+    val predictions = model
+      .transform(features)
+      .select("prediction").map {
+        case Row(pred) => pred
+      }.collect()
+
+    assert(predictions === Array(3.5, 5.0, 5.0, 5.0))
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 7cdda3db88ad1..2aaee71ecc734 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
@@ -70,7 +71,12 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(lir.getRegParam === 0.0)
assert(lir.getElasticNetParam === 0.0)
assert(lir.getFitIntercept)
+ assert(lir.getStandardization)
val model = lir.fit(dataset)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
model.transform(dataset)
.select("label", "prediction")
.collect()
@@ -81,8 +87,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("linear regression with intercept without regularization") {
- val trainer = new LinearRegression
- val model = trainer.fit(dataset)
+ val trainer1 = new LinearRegression
+ // The result should be the same regardless of standardization without regularization
+ val trainer2 = (new LinearRegression).setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
Using the following R code to load the data and train the model using glmnet package.
@@ -95,28 +104,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
> weights
3 x 1 sparse Matrix of class "dgCMatrix"
s0
- (Intercept) 6.300528
- as.numeric.data.V2. 4.701024
- as.numeric.data.V3. 7.198257
+ (Intercept) 6.298698
+ as.numeric.data.V2. 4.700706
+ as.numeric.data.V3. 7.199082
*/
val interceptR = 6.298698
val weightsR = Vectors.dense(4.700706, 7.199082)
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR relTol 1E-3)
+ assert(model1.weights ~= weightsR relTol 1E-3)
+ assert(model2.intercept ~== interceptR relTol 1E-3)
+ assert(model2.weights ~= weightsR relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression without intercept without regularization") {
- val trainer = (new LinearRegression).setFitIntercept(false)
- val model = trainer.fit(dataset)
- val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)
+ val trainer1 = (new LinearRegression).setFitIntercept(false)
+ // Without regularization the results should be the same
+ val trainer2 = (new LinearRegression).setFitIntercept(false).setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val modelWithoutIntercept1 = trainer1.fit(datasetWithoutIntercept)
+ val model2 = trainer2.fit(dataset)
+ val modelWithoutIntercept2 = trainer2.fit(datasetWithoutIntercept)
+
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
@@ -130,26 +147,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
val weightsR = Vectors.dense(6.995908, 5.275131)
- assert(model.intercept ~== 0 absTol 1E-3)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model1.intercept ~== 0 absTol 1E-3)
+ assert(model1.weights ~= weightsR relTol 1E-3)
+ assert(model2.intercept ~== 0 absTol 1E-3)
+ assert(model2.weights ~= weightsR relTol 1E-3)
+
/*
Then again with the data with no intercept:
> weightsWithoutIntercept
3 x 1 sparse Matrix of class "dgCMatrix"
- s0
+ s0
(Intercept) .
as.numeric.data3.V2. 4.70011
as.numeric.data3.V3. 7.19943
*/
val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943)
- assert(modelWithoutIntercept.intercept ~== 0 absTol 1E-3)
- assert(modelWithoutIntercept.weights ~= weightsWithoutInterceptR relTol 1E-3)
+ assert(modelWithoutIntercept1.intercept ~== 0 absTol 1E-3)
+ assert(modelWithoutIntercept1.weights ~= weightsWithoutInterceptR relTol 1E-3)
+ assert(modelWithoutIntercept2.intercept ~== 0 absTol 1E-3)
+ assert(modelWithoutIntercept2.weights ~= weightsWithoutInterceptR relTol 1E-3)
}
test("linear regression with intercept with L1 regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
- val model = trainer.fit(dataset)
+ val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ .setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
@@ -160,24 +185,44 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V2. 4.024821
as.numeric.data.V3. 6.679841
*/
- val interceptR = 6.24300
- val weightsR = Vectors.dense(4.024821, 6.679841)
+ val interceptR1 = 6.24300
+ val weightsR1 = Vectors.dense(4.024821, 6.679841)
+
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+ standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.416948
+ as.numeric.data.V2. 3.893869
+ as.numeric.data.V3. 6.724286
+ */
+ val interceptR2 = 6.416948
+ val weightsR2 = Vectors.dense(3.893869, 6.724286)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights ~= weightsR relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression without intercept with L1 regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ val trainer1 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
.setFitIntercept(false)
- val model = trainer.fit(dataset)
+ val trainer2 = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ .setFitIntercept(false).setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
@@ -189,51 +234,90 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V2. 6.299752
as.numeric.data.V3. 4.772913
*/
- val interceptR = 0.0
- val weightsR = Vectors.dense(6.299752, 4.772913)
+ val interceptR1 = 0.0
+ val weightsR1 = Vectors.dense(6.299752, 4.772913)
+
+ assert(model1.intercept ~== interceptR1 absTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
+ intercept=FALSE, standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 6.232193
+ as.numeric.data.V3. 4.764229
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Vectors.dense(6.232193, 4.764229)
- assert(model.intercept ~== interceptR absTol 1E-5)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model2.intercept ~== interceptR2 absTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression with intercept with L2 regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
- val model = trainer.fit(dataset)
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ .setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
- weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
- > weights
- 3 x 1 sparse Matrix of class "dgCMatrix"
- s0
- (Intercept) 6.328062
- as.numeric.data.V2. 3.222034
- as.numeric.data.V3. 4.926260
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 5.269376
+ as.numeric.data.V2. 3.736216
+      as.numeric.data.V3. 5.712356
*/
- val interceptR = 5.269376
- val weightsR = Vectors.dense(3.736216, 5.712356)
+ val interceptR1 = 5.269376
+ val weightsR1 = Vectors.dense(3.736216, 5.712356)
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+ standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 5.791109
+ as.numeric.data.V2. 3.435466
+ as.numeric.data.V3. 5.910406
+ */
+ val interceptR2 = 5.791109
+ val weightsR2 = Vectors.dense(3.435466, 5.910406)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression without intercept with L2 regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
.setFitIntercept(false)
- val model = trainer.fit(dataset)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ .setFitIntercept(false).setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
@@ -245,23 +329,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V2. 5.522875
as.numeric.data.V3. 4.214502
*/
- val interceptR = 0.0
- val weightsR = Vectors.dense(5.522875, 4.214502)
+ val interceptR1 = 0.0
+ val weightsR1 = Vectors.dense(5.522875, 4.214502)
+
+ assert(model1.intercept ~== interceptR1 absTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
+
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+ intercept = FALSE, standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 5.263704
+ as.numeric.data.V3. 4.187419
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Vectors.dense(5.263704, 4.187419)
- assert(model.intercept ~== interceptR absTol 1E-3)
- assert(model.weights ~== weightsR relTol 1E-3)
+ assert(model2.intercept ~== interceptR2 absTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression with intercept with ElasticNet regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
- val model = trainer.fit(dataset)
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ .setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
@@ -272,24 +375,43 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V2. 3.168435
as.numeric.data.V3. 5.200403
*/
- val interceptR = 5.696056
- val weightsR = Vectors.dense(3.670489, 6.001122)
+ val interceptR1 = 5.696056
+ val weightsR1 = Vectors.dense(3.670489, 6.001122)
- assert(model.intercept ~== interceptR relTol 1E-3)
- assert(model.weights ~== weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR1 relTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ /*
+       weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
+ standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) 6.114723
+ as.numeric.data.V2. 3.409937
+ as.numeric.data.V3. 6.146531
+ */
+ val interceptR2 = 6.114723
+ val weightsR2 = Vectors.dense(3.409937, 6.146531)
+
+ assert(model2.intercept ~== interceptR2 relTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression without intercept with ElasticNet regularization") {
- val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ val trainer1 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
.setFitIntercept(false)
- val model = trainer.fit(dataset)
+ val trainer2 = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ .setFitIntercept(false).setStandardization(false)
+ val model1 = trainer1.fit(dataset)
+ val model2 = trainer2.fit(dataset)
/*
weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
@@ -301,16 +423,32 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.dataM.V2. 5.673348
as.numeric.dataM.V3. 4.322251
*/
- val interceptR = 0.0
- val weightsR = Vectors.dense(5.673348, 4.322251)
+ val interceptR1 = 0.0
+ val weightsR1 = Vectors.dense(5.673348, 4.322251)
- assert(model.intercept ~== interceptR absTol 1E-3)
- assert(model.weights ~= weightsR relTol 1E-3)
+ assert(model1.intercept ~== interceptR1 absTol 1E-3)
+ assert(model1.weights ~= weightsR1 relTol 1E-3)
- model.transform(dataset).select("features", "prediction").collect().foreach {
+ /*
+ weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
+ intercept=FALSE, standardize=FALSE))
+ > weights
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ (Intercept) .
+ as.numeric.data.V2. 5.477988
+ as.numeric.data.V3. 4.297622
+ */
+ val interceptR2 = 0.0
+ val weightsR2 = Vectors.dense(5.477988, 4.297622)
+
+ assert(model2.intercept ~== interceptR2 absTol 1E-3)
+ assert(model2.weights ~= weightsR2 relTol 1E-3)
+
+ model1.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
- features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ features(0) * model1.weights(0) + features(1) * model1.weights(1) + model1.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
@@ -372,5 +510,4 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
.zip(testSummary.residuals.select("residuals").collect())
.forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
}
-
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index b24ecaa57c89b..7b1b3f11481de 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -26,7 +28,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-
/**
* Test suite for [[RandomForestRegressor]].
*/
@@ -71,6 +72,35 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
regressionTestWithContinuousFeatures(rf)
}
+ test("Feature importance with toy data") {
+ val rf = new RandomForestRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(3)
+ .setNumTrees(3)
+ .setFeatureSubsetStrategy("all")
+ .setSubsamplingRate(1.0)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = sc.parallelize(Seq(
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)),
+ new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)),
+ new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
+ ))
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+ val model = rf.fit(df)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+ val importances = model.featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
new file mode 100644
index 0000000000000..dc852795c7f62
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.ml.tree.{ContinuousSplit, DecisionTreeModel, LeafNode, Node}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.tree.impurity.GiniCalculator
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * Test suite for [[RandomForest]].
+ */
+class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ import RandomForestSuite.mapToVec
+
+ test("computeFeatureImportance, featureImportances") {
+ /* Build tree for testing, with this structure:
+ grandParent
+ left2 parent
+ left right
+ */
+ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
+ val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
+
+ val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
+ val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
+
+ val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
+ val parentImp = parent.impurityStats
+
+ val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
+ val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
+
+ val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
+ val grandImp = grandParent.impurityStats
+
+ // Test feature importance computed at different subtrees.
+ def testNode(node: Node, expected: Map[Int, Double]): Unit = {
+ val map = new OpenHashMap[Int, Double]()
+ RandomForest.computeFeatureImportance(node, map)
+ assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
+ }
+
+ // Leaf node
+ testNode(left, Map.empty[Int, Double])
+
+ // Internal node with 2 leaf children
+ val feature0importance = parentImp.calculate() * parentImp.count -
+ (leftImp.calculate() * leftImp.count + rightImp.calculate() * rightImp.count)
+ testNode(parent, Map(0 -> feature0importance))
+
+ // Full tree
+ val feature1importance = grandImp.calculate() * grandImp.count -
+ (left2Imp.calculate() * left2Imp.count + parentImp.calculate() * parentImp.count)
+ testNode(grandParent, Map(0 -> feature0importance, 1 -> feature1importance))
+
+ // Forest consisting of (full tree) + (internal node with 2 leafs)
+ val trees = Array(parent, grandParent).map { root =>
+ new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel]
+ }
+ val importances: Vector = RandomForest.featureImportances(trees, 2)
+ val tree2norm = feature0importance + feature1importance
+ val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
+ (feature1importance / tree2norm) / 2.0)
+ assert(importances ~== expected relTol 0.01)
+ }
+
+ test("normalizeMapValues") {
+ val map = new OpenHashMap[Int, Double]()
+ map(0) = 1.0
+ map(2) = 2.0
+ RandomForest.normalizeMapValues(map)
+ val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
+ assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
+ }
+
+}
+
+private object RandomForestSuite {
+
+ def mapToVec(map: Map[Int, Double]): Vector = {
+ val size = (map.keys.toSeq :+ 0).max + 1
+ val (indices, values) = map.toSeq.sortBy(_._1).unzip
+ Vectors.sparse(size, indices.toArray, values.toArray)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index db64511a76055..aaca08bb61a45 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
@@ -53,6 +54,10 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setEvaluator(eval)
.setNumFolds(3)
val cvModel = cv.fit(dataset)
+
+    // copied model must have the same parent.
+ MLTestingUtils.checkCopy(cvModel)
+
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
new file mode 100644
index 0000000000000..d290cc9b06e73
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param.ParamMap
+
+object MLTestingUtils {
+ def checkCopy(model: Model[_]): Unit = {
+ val copied = model.copy(ParamMap.empty)
+ .asInstanceOf[Model[_]]
+ assert(copied.parent.uid == model.parent.uid)
+ assert(copied.parent == model.parent)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index fd653296c9d97..d7b291d5a6330 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -24,13 +24,22 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.streaming.dstream.DStream
-import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 30000
+ var ssc: StreamingContext = _
+
+ override def afterFunction() {
+ super.afterFunction()
+ if (ssc != null) {
+ ssc.stop()
+ }
+ }
+
// Test if we can accurately learn B for Y = logistic(BX) on streaming data
test("parameter accuracy") {
@@ -50,7 +59,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
}
// apply model training to input stream
- val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
inputDStream.count()
})
@@ -84,7 +93,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
// apply model training to input stream, storing the intermediate results
// (we add a count to ensure the result is a DStream)
- val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B)))
inputDStream.count()
@@ -118,7 +127,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
}
// apply model predictions to test stream
- val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
})
@@ -147,7 +156,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
}
// train and predict
- val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
})
@@ -167,7 +176,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
.setNumIterations(10)
val numBatches = 10
val emptyInput = Seq.empty[Seq[LabeledPoint]]
- val ssc = setupStreams(emptyInput,
+ ssc = setupStreams(emptyInput,
(inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index b218d72f1268a..b636d02f786e6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("model prediction, parallel and local") {
+ val data = sc.parallelize(GaussianTestData.data)
+ val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
+
+ val batchPredictions = gmm.predict(data)
+ batchPredictions.zip(data).collect().foreach { case (batchPred, datum) =>
+ assert(batchPred === gmm.predict(datum))
+ }
+ }
+
object GaussianTestData {
val data = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index c43e1e575c09c..926185e90bcf9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM, max, argmax}
+import java.util.{ArrayList => JArrayList}
+
+import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax}
import org.apache.spark.SparkFunSuite
import org.apache.spark.graphx.Edge
@@ -108,9 +110,42 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5)
}
+ val top2TopicsPerDoc = model.topTopicsPerDocument(2).map(t => (t._1, (t._2, t._3)))
+ model.topicDistributions.join(top2TopicsPerDoc).collect().foreach {
+ case (docId, (topicDistribution, (indices, weights))) =>
+ assert(indices.length == 2)
+ assert(weights.length == 2)
+ val bdvTopicDist = topicDistribution.toBreeze
+ val top2Indices = argtopk(bdvTopicDist, 2)
+ assert(top2Indices.toArray === indices)
+ assert(bdvTopicDist(top2Indices).toArray === weights)
+ }
+
// Check: log probabilities
assert(model.logLikelihood < 0.0)
assert(model.logPrior < 0.0)
+
+ // Check: topDocumentsPerTopic
+ // Compare it with top documents per topic derived from topicDistributions
+ val topDocsByTopicDistributions = { n: Int =>
+ Range(0, k).map { topic =>
+ val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip
+ (doc.toArray, docWeights.map(_(topic)).toArray)
+ }.toArray
+ }
+
+ // Top 3 documents per topic
+ model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) =>
+ assert(t1._1 === t2._1)
+ assert(t1._2 === t2._2)
+ }
+
+ // All documents per topic
+ val q = tinyCorpus.length
+ model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) =>
+ assert(t1._1 === t2._1)
+ assert(t1._2 === t2._2)
+ }
}
test("vertex indexing") {
@@ -127,8 +162,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
test("setter alias") {
val lda = new LDA().setAlpha(2.0).setBeta(3.0)
- assert(lda.getAlpha.toArray.forall(_ === 2.0))
- assert(lda.getDocConcentration.toArray.forall(_ === 2.0))
+ assert(lda.getAsymmetricAlpha.toArray.forall(_ === 2.0))
+ assert(lda.getAsymmetricDocConcentration.toArray.forall(_ === 2.0))
assert(lda.getBeta === 3.0)
assert(lda.getTopicConcentration === 3.0)
}
@@ -199,16 +234,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("OnlineLDAOptimizer with toy data") {
- def toydata: Array[(Long, Vector)] = Array(
- Vectors.sparse(6, Array(0, 1), Array(1, 1)),
- Vectors.sparse(6, Array(1, 2), Array(1, 1)),
- Vectors.sparse(6, Array(0, 2), Array(1, 1)),
- Vectors.sparse(6, Array(3, 4), Array(1, 1)),
- Vectors.sparse(6, Array(3, 5), Array(1, 1)),
- Vectors.sparse(6, Array(4, 5), Array(1, 1))
- ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
-
- val docs = sc.parallelize(toydata)
+ val docs = sc.parallelize(toyData)
val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
.setGammaShape(1e10)
val lda = new LDA().setK(2)
@@ -231,30 +257,45 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
- test("LocalLDAModel logPerplexity") {
- val k = 2
- val vocabSize = 6
- val alpha = 0.01
- val eta = 0.01
- val gammaShape = 100
- // obtained from LDA model trained in gensim, see below
- val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
- 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
- 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
+ test("LocalLDAModel logLikelihood") {
+ val ldaModel: LocalLDAModel = toyModel
- def toydata: Array[(Long, Vector)] = Array(
- Vectors.sparse(6, Array(0, 1), Array(1, 1)),
- Vectors.sparse(6, Array(1, 2), Array(1, 1)),
- Vectors.sparse(6, Array(0, 2), Array(1, 1)),
- Vectors.sparse(6, Array(3, 4), Array(1, 1)),
- Vectors.sparse(6, Array(3, 5), Array(1, 1)),
- Vectors.sparse(6, Array(4, 5), Array(1, 1))
- ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
- val docs = sc.parallelize(toydata)
+ val docsSingleWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(1)))
+ .zipWithIndex
+ .map { case (wordCounts, docId) => (docId.toLong, wordCounts) })
+ val docsRepeatedWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(5)))
+ .zipWithIndex
+ .map { case (wordCounts, docId) => (docId.toLong, wordCounts) })
+ /* Verify results using gensim:
+ import numpy as np
+ from gensim import models
+ corpus = [
+ [(0, 1.0), (1, 1.0)],
+ [(1, 1.0), (2, 1.0)],
+ [(0, 1.0), (2, 1.0)],
+ [(3, 1.0), (4, 1.0)],
+ [(3, 1.0), (5, 1.0)],
+ [(4, 1.0), (5, 1.0)]]
+ np.random.seed(2345)
+ lda = models.ldamodel.LdaModel(
+ corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100,
+ decay=0.51, offset=1024)
+ docsSingleWord = [[(0, 1.0)]]
+ docsRepeatedWord = [[(0, 5.0)]]
+ print(lda.bound(docsSingleWord))
+ > -25.9706969833
+ print(lda.bound(docsRepeatedWord))
+ > -31.4413908227
+ */
- val ldaModel: LocalLDAModel = new LocalLDAModel(
- topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape)
+ assert(ldaModel.logLikelihood(docsSingleWord) ~== -25.971 relTol 1E-3D)
+ assert(ldaModel.logLikelihood(docsRepeatedWord) ~== -31.441 relTol 1E-3D)
+ }
+
+ test("LocalLDAModel logPerplexity") {
+ val docs = sc.parallelize(toyData)
+ val ldaModel: LocalLDAModel = toyModel
/* Verify results using gensim:
import numpy as np
@@ -274,32 +315,13 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
> -3.69051285096
*/
- assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D)
+    // Gensim's definition of perplexity is the negative of our (and Stanford NLP's) definition
+ assert(ldaModel.logPerplexity(docs) ~== 3.690D relTol 1E-3D)
}
test("LocalLDAModel predict") {
- val k = 2
- val vocabSize = 6
- val alpha = 0.01
- val eta = 0.01
- val gammaShape = 100
- // obtained from LDA model trained in gensim, see below
- val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
- 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
- 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
-
- def toydata: Array[(Long, Vector)] = Array(
- Vectors.sparse(6, Array(0, 1), Array(1, 1)),
- Vectors.sparse(6, Array(1, 2), Array(1, 1)),
- Vectors.sparse(6, Array(0, 2), Array(1, 1)),
- Vectors.sparse(6, Array(3, 4), Array(1, 1)),
- Vectors.sparse(6, Array(3, 5), Array(1, 1)),
- Vectors.sparse(6, Array(4, 5), Array(1, 1))
- ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
- val docs = sc.parallelize(toydata)
-
- val ldaModel: LocalLDAModel = new LocalLDAModel(
- topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape)
+ val docs = sc.parallelize(toyData)
+ val ldaModel: LocalLDAModel = toyModel
/* Verify results using gensim:
import numpy as np
@@ -340,16 +362,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("OnlineLDAOptimizer with asymmetric prior") {
- def toydata: Array[(Long, Vector)] = Array(
- Vectors.sparse(6, Array(0, 1), Array(1, 1)),
- Vectors.sparse(6, Array(1, 2), Array(1, 1)),
- Vectors.sparse(6, Array(0, 2), Array(1, 1)),
- Vectors.sparse(6, Array(3, 4), Array(1, 1)),
- Vectors.sparse(6, Array(3, 5), Array(1, 1)),
- Vectors.sparse(6, Array(4, 5), Array(1, 1))
- ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
-
- val docs = sc.parallelize(toydata)
+ val docs = sc.parallelize(toyData)
val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
.setGammaShape(1e10)
val lda = new LDA().setK(2)
@@ -389,6 +402,40 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("OnlineLDAOptimizer alpha hyperparameter optimization") {
+ val k = 2
+ val docs = sc.parallelize(toyData)
+ val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
+ .setGammaShape(100).setOptimzeAlpha(true).setSampleWithReplacement(false)
+ val lda = new LDA().setK(k)
+ .setDocConcentration(1D / k)
+ .setTopicConcentration(0.01)
+ .setMaxIterations(100)
+ .setOptimizer(op)
+ .setSeed(12345)
+ val ldaModel: LocalLDAModel = lda.run(docs).asInstanceOf[LocalLDAModel]
+
+ /* Verify the results with gensim:
+ import numpy as np
+ from gensim import models
+ corpus = [
+ [(0, 1.0), (1, 1.0)],
+ [(1, 1.0), (2, 1.0)],
+ [(0, 1.0), (2, 1.0)],
+ [(3, 1.0), (4, 1.0)],
+ [(3, 1.0), (5, 1.0)],
+ [(4, 1.0), (5, 1.0)]]
+ np.random.seed(2345)
+ lda = models.ldamodel.LdaModel(
+ corpus=corpus, alpha='auto', eta=0.01, num_topics=2, update_every=0, passes=100,
+ decay=0.51, offset=1024)
+ print(lda.alpha)
+ > [ 0.42582646 0.43511073]
+ */
+
+ assert(ldaModel.docConcentration ~== Vectors.dense(0.42582646, 0.43511073) absTol 0.05)
+ }
+
test("model save/load") {
// Test for LocalLDAModel.
val localModel = new LocalLDAModel(tinyTopics,
@@ -520,4 +567,38 @@ private[clustering] object LDASuite {
def getNonEmptyDoc(corpus: Array[(Long, Vector)]): Array[(Long, Vector)] = corpus.filter {
case (_, wc: Vector) => Vectors.norm(wc, p = 1.0) != 0.0
}
+
+ def toyData: Array[(Long, Vector)] = Array(
+ Vectors.sparse(6, Array(0, 1), Array(1, 1)),
+ Vectors.sparse(6, Array(1, 2), Array(1, 1)),
+ Vectors.sparse(6, Array(0, 2), Array(1, 1)),
+ Vectors.sparse(6, Array(3, 4), Array(1, 1)),
+ Vectors.sparse(6, Array(3, 5), Array(1, 1)),
+ Vectors.sparse(6, Array(4, 5), Array(1, 1))
+ ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+
+ /** Used in the Java Test Suite */
+ def javaToyData: JArrayList[(java.lang.Long, Vector)] = {
+ val javaData = new JArrayList[(java.lang.Long, Vector)]
+ var i = 0
+ while (i < toyData.size) {
+ javaData.add((toyData(i)._1, toyData(i)._2))
+ i += 1
+ }
+ javaData
+ }
+
+ def toyModel: LocalLDAModel = {
+ val k = 2
+ val vocabSize = 6
+ val alpha = 0.01
+ val eta = 0.01
+ val gammaShape = 100
+ val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
+ 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
+ 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
+ val ldaModel: LocalLDAModel = new LocalLDAModel(
+ topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape)
+ ldaModel
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
index ac01622b8a089..3645d29dccdb2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.random.XORShiftRandom
@@ -28,6 +28,15 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
override def maxWaitTimeMillis: Int = 30000
+ var ssc: StreamingContext = _
+
+ override def afterFunction() {
+ super.afterFunction()
+ if (ssc != null) {
+ ssc.stop()
+ }
+ }
+
test("accuracy for single center and equivalence to grand average") {
// set parameters
val numBatches = 10
@@ -46,7 +55,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
// setup and run the model training
- val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
model.trainOn(inputDStream)
inputDStream.count()
})
@@ -82,7 +91,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
// setup and run the model training
- val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
kMeans.trainOn(inputDStream)
inputDStream.count()
})
@@ -114,7 +123,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0)))
// setup and run the model training
- val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
kMeans.trainOn(inputDStream)
inputDStream.count()
})
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 6dd2dc926acc5..a83e543859b8a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
- test("PrefixSpan using Integer type") {
+ test("PrefixSpan internal (integer seq, 0 delim) run, singleton itemsets") {
/*
library("arulesSequences")
@@ -35,79 +35,345 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
val sequences = Array(
- Array(1, 3, 4, 5),
- Array(2, 3, 1),
- Array(2, 4, 1),
- Array(3, 1, 3, 4, 5),
- Array(3, 4, 4, 3),
- Array(6, 5, 3))
+ Array(0, 1, 0, 3, 0, 4, 0, 5, 0),
+ Array(0, 2, 0, 3, 0, 1, 0),
+ Array(0, 2, 0, 4, 0, 1, 0),
+ Array(0, 3, 0, 1, 0, 3, 0, 4, 0, 5, 0),
+ Array(0, 3, 0, 4, 0, 4, 0, 3, 0),
+ Array(0, 6, 0, 5, 0, 3, 0))
val rdd = sc.parallelize(sequences, 2).cache()
- val prefixspan = new PrefixSpan()
- .setMinSupport(0.33)
- .setMaxPatternLength(50)
- val result1 = prefixspan.run(rdd)
+ val result1 = PrefixSpan.genFreqPatterns(
+ rdd, minCount = 2L, maxPatternLength = 50, maxLocalProjDBSize = 16L)
val expectedValue1 = Array(
- (Array(1), 4L),
- (Array(1, 3), 2L),
- (Array(1, 3, 4), 2L),
- (Array(1, 3, 4, 5), 2L),
- (Array(1, 3, 5), 2L),
- (Array(1, 4), 2L),
- (Array(1, 4, 5), 2L),
- (Array(1, 5), 2L),
- (Array(2), 2L),
- (Array(2, 1), 2L),
- (Array(3), 5L),
- (Array(3, 1), 2L),
- (Array(3, 3), 2L),
- (Array(3, 4), 3L),
- (Array(3, 4, 5), 2L),
- (Array(3, 5), 2L),
- (Array(4), 4L),
- (Array(4, 5), 2L),
- (Array(5), 3L)
+ (Array(0, 1, 0), 4L),
+ (Array(0, 1, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 4, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 4, 0, 5, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 5, 0), 2L),
+ (Array(0, 1, 0, 4, 0), 2L),
+ (Array(0, 1, 0, 4, 0, 5, 0), 2L),
+ (Array(0, 1, 0, 5, 0), 2L),
+ (Array(0, 2, 0), 2L),
+ (Array(0, 2, 0, 1, 0), 2L),
+ (Array(0, 3, 0), 5L),
+ (Array(0, 3, 0, 1, 0), 2L),
+ (Array(0, 3, 0, 3, 0), 2L),
+ (Array(0, 3, 0, 4, 0), 3L),
+ (Array(0, 3, 0, 4, 0, 5, 0), 2L),
+ (Array(0, 3, 0, 5, 0), 2L),
+ (Array(0, 4, 0), 4L),
+ (Array(0, 4, 0, 5, 0), 2L),
+ (Array(0, 5, 0), 3L)
)
- assert(compareResults(expectedValue1, result1.collect()))
+ compareInternalResults(expectedValue1, result1.collect())
- prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
- val result2 = prefixspan.run(rdd)
+ val result2 = PrefixSpan.genFreqPatterns(
+ rdd, minCount = 3, maxPatternLength = 50, maxLocalProjDBSize = 32L)
val expectedValue2 = Array(
- (Array(1), 4L),
- (Array(3), 5L),
- (Array(3, 4), 3L),
- (Array(4), 4L),
- (Array(5), 3L)
+ (Array(0, 1, 0), 4L),
+ (Array(0, 3, 0), 5L),
+ (Array(0, 3, 0, 4, 0), 3L),
+ (Array(0, 4, 0), 4L),
+ (Array(0, 5, 0), 3L)
)
- assert(compareResults(expectedValue2, result2.collect()))
+ compareInternalResults(expectedValue2, result2.collect())
- prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
- val result3 = prefixspan.run(rdd)
+ val result3 = PrefixSpan.genFreqPatterns(
+ rdd, minCount = 2, maxPatternLength = 2, maxLocalProjDBSize = 32L)
val expectedValue3 = Array(
- (Array(1), 4L),
- (Array(1, 3), 2L),
- (Array(1, 4), 2L),
- (Array(1, 5), 2L),
- (Array(2, 1), 2L),
- (Array(2), 2L),
- (Array(3), 5L),
- (Array(3, 1), 2L),
- (Array(3, 3), 2L),
- (Array(3, 4), 3L),
- (Array(3, 5), 2L),
- (Array(4), 4L),
- (Array(4, 5), 2L),
- (Array(5), 3L)
+ (Array(0, 1, 0), 4L),
+ (Array(0, 1, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 4, 0), 2L),
+ (Array(0, 1, 0, 5, 0), 2L),
+ (Array(0, 2, 0, 1, 0), 2L),
+ (Array(0, 2, 0), 2L),
+ (Array(0, 3, 0), 5L),
+ (Array(0, 3, 0, 1, 0), 2L),
+ (Array(0, 3, 0, 3, 0), 2L),
+ (Array(0, 3, 0, 4, 0), 3L),
+ (Array(0, 3, 0, 5, 0), 2L),
+ (Array(0, 4, 0), 4L),
+ (Array(0, 4, 0, 5, 0), 2L),
+ (Array(0, 5, 0), 3L)
)
- assert(compareResults(expectedValue3, result3.collect()))
+ compareInternalResults(expectedValue3, result3.collect())
}
- private def compareResults(
- expectedValue: Array[(Array[Int], Long)],
- actualValue: Array[(Array[Int], Long)]): Boolean = {
- expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
- actualValue.map(x => (x._1.toSeq, x._2)).toSet
+ test("PrefixSpan internal (integer seq, -1 delim) run, variable-size itemsets") {
+ val sequences = Array(
+ Array(0, 1, 0, 1, 2, 3, 0, 1, 3, 0, 4, 0, 3, 6, 0),
+ Array(0, 1, 4, 0, 3, 0, 2, 3, 0, 1, 5, 0),
+ Array(0, 5, 6, 0, 1, 2, 0, 4, 6, 0, 3, 0, 2, 0),
+ Array(0, 5, 0, 7, 0, 1, 6, 0, 3, 0, 2, 0, 3, 0))
+ val rdd = sc.parallelize(sequences, 2).cache()
+ val result = PrefixSpan.genFreqPatterns(
+ rdd, minCount = 2, maxPatternLength = 5, maxLocalProjDBSize = 128L)
+
+ /*
+ To verify results, create file "prefixSpanSeqs" with content
+ (format = (transactionID, idxInTransaction, numItemsInItemset, itemset)):
+ 1 1 1 1
+ 1 2 3 1 2 3
+ 1 3 2 1 3
+ 1 4 1 4
+ 1 5 2 3 6
+ 2 1 2 1 4
+ 2 2 1 3
+ 2 3 2 2 3
+ 2 4 2 1 5
+ 3 1 2 5 6
+ 3 2 2 1 2
+ 3 3 2 4 6
+ 3 4 1 3
+ 3 5 1 2
+ 4 1 1 5
+ 4 2 1 7
+ 4 3 2 1 6
+ 4 4 1 3
+ 4 5 1 2
+ 4 6 1 3
+ In R, run:
+ library("arulesSequences")
+ prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE"))
+ freqItemSeq = cspade(prefixSpanSeqs,
+ parameter = list(support = 0.5, maxlen = 5 ))
+ resSeq = as(freqItemSeq, "data.frame")
+ resSeq
+
+ sequence support
+ 1 <{1}> 1.00
+ 2 <{2}> 1.00
+ 3 <{3}> 1.00
+ 4 <{4}> 0.75
+ 5 <{5}> 0.75
+ 6 <{6}> 0.75
+ 7 <{1},{6}> 0.50
+ 8 <{2},{6}> 0.50
+ 9 <{5},{6}> 0.50
+ 10 <{1,2},{6}> 0.50
+ 11 <{1},{4}> 0.50
+ 12 <{2},{4}> 0.50
+ 13 <{1,2},{4}> 0.50
+ 14 <{1},{3}> 1.00
+ 15 <{2},{3}> 0.75
+ 16 <{2,3}> 0.50
+ 17 <{3},{3}> 0.75
+ 18 <{4},{3}> 0.75
+ 19 <{5},{3}> 0.50
+ 20 <{6},{3}> 0.50
+ 21 <{5},{6},{3}> 0.50
+ 22 <{6},{2},{3}> 0.50
+ 23 <{5},{2},{3}> 0.50
+ 24 <{5},{1},{3}> 0.50
+ 25 <{2},{4},{3}> 0.50
+ 26 <{1},{4},{3}> 0.50
+ 27 <{1,2},{4},{3}> 0.50
+ 28 <{1},{3},{3}> 0.75
+ 29 <{1,2},{3}> 0.50
+ 30 <{1},{2},{3}> 0.50
+ 31 <{1},{2,3}> 0.50
+ 32 <{1},{2}> 1.00
+ 33 <{1,2}> 0.50
+ 34 <{3},{2}> 0.75
+ 35 <{4},{2}> 0.50
+ 36 <{5},{2}> 0.50
+ 37 <{6},{2}> 0.50
+ 38 <{5},{6},{2}> 0.50
+ 39 <{6},{3},{2}> 0.50
+ 40 <{5},{3},{2}> 0.50
+ 41 <{5},{1},{2}> 0.50
+ 42 <{4},{3},{2}> 0.50
+ 43 <{1},{3},{2}> 0.75
+ 44 <{5},{6},{3},{2}> 0.50
+ 45 <{5},{1},{3},{2}> 0.50
+ 46 <{1},{1}> 0.50
+ 47 <{2},{1}> 0.50
+ 48 <{3},{1}> 0.50
+ 49 <{5},{1}> 0.50
+ 50 <{2,3},{1}> 0.50
+ 51 <{1},{3},{1}> 0.50
+ 52 <{1},{2,3},{1}> 0.50
+ 53 <{1},{2},{1}> 0.50
+ */
+ val expectedValue = Array(
+ (Array(0, 1, 0), 4L),
+ (Array(0, 2, 0), 4L),
+ (Array(0, 3, 0), 4L),
+ (Array(0, 4, 0), 3L),
+ (Array(0, 5, 0), 3L),
+ (Array(0, 6, 0), 3L),
+ (Array(0, 1, 0, 6, 0), 2L),
+ (Array(0, 2, 0, 6, 0), 2L),
+ (Array(0, 5, 0, 6, 0), 2L),
+ (Array(0, 1, 2, 0, 6, 0), 2L),
+ (Array(0, 1, 0, 4, 0), 2L),
+ (Array(0, 2, 0, 4, 0), 2L),
+ (Array(0, 1, 2, 0, 4, 0), 2L),
+ (Array(0, 1, 0, 3, 0), 4L),
+ (Array(0, 2, 0, 3, 0), 3L),
+ (Array(0, 2, 3, 0), 2L),
+ (Array(0, 3, 0, 3, 0), 3L),
+ (Array(0, 4, 0, 3, 0), 3L),
+ (Array(0, 5, 0, 3, 0), 2L),
+ (Array(0, 6, 0, 3, 0), 2L),
+ (Array(0, 5, 0, 6, 0, 3, 0), 2L),
+ (Array(0, 6, 0, 2, 0, 3, 0), 2L),
+ (Array(0, 5, 0, 2, 0, 3, 0), 2L),
+ (Array(0, 5, 0, 1, 0, 3, 0), 2L),
+ (Array(0, 2, 0, 4, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 4, 0, 3, 0), 2L),
+ (Array(0, 1, 2, 0, 4, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 3, 0), 3L),
+ (Array(0, 1, 2, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 2, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 2, 3, 0), 2L),
+ (Array(0, 1, 0, 2, 0), 4L),
+ (Array(0, 1, 2, 0), 2L),
+ (Array(0, 3, 0, 2, 0), 3L),
+ (Array(0, 4, 0, 2, 0), 2L),
+ (Array(0, 5, 0, 2, 0), 2L),
+ (Array(0, 6, 0, 2, 0), 2L),
+ (Array(0, 5, 0, 6, 0, 2, 0), 2L),
+ (Array(0, 6, 0, 3, 0, 2, 0), 2L),
+ (Array(0, 5, 0, 3, 0, 2, 0), 2L),
+ (Array(0, 5, 0, 1, 0, 2, 0), 2L),
+ (Array(0, 4, 0, 3, 0, 2, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 2, 0), 3L),
+ (Array(0, 5, 0, 6, 0, 3, 0, 2, 0), 2L),
+ (Array(0, 5, 0, 1, 0, 3, 0, 2, 0), 2L),
+ (Array(0, 1, 0, 1, 0), 2L),
+ (Array(0, 2, 0, 1, 0), 2L),
+ (Array(0, 3, 0, 1, 0), 2L),
+ (Array(0, 5, 0, 1, 0), 2L),
+ (Array(0, 2, 3, 0, 1, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 1, 0), 2L),
+ (Array(0, 1, 0, 2, 3, 0, 1, 0), 2L),
+ (Array(0, 1, 0, 2, 0, 1, 0), 2L))
+
+ compareInternalResults(expectedValue, result.collect())
}
+ test("PrefixSpan projections with multiple partial starts") {
+ val sequences = Seq(
+ Array(Array(1, 2), Array(1, 2, 3)))
+ val rdd = sc.parallelize(sequences, 2)
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(1.0)
+ .setMaxPatternLength(2)
+ val model = prefixSpan.run(rdd)
+ val expected = Array(
+ (Array(Array(1)), 1L),
+ (Array(Array(1, 2)), 1L),
+ (Array(Array(1), Array(1)), 1L),
+ (Array(Array(1), Array(2)), 1L),
+ (Array(Array(1), Array(3)), 1L),
+ (Array(Array(1, 3)), 1L),
+ (Array(Array(2)), 1L),
+ (Array(Array(2, 3)), 1L),
+ (Array(Array(2), Array(1)), 1L),
+ (Array(Array(2), Array(2)), 1L),
+ (Array(Array(2), Array(3)), 1L),
+ (Array(Array(3)), 1L))
+ compareResults(expected, model.freqSequences.collect())
+ }
+
+ test("PrefixSpan Integer type, variable-size itemsets") {
+ val sequences = Seq(
+ Array(Array(1, 2), Array(3)),
+ Array(Array(1), Array(3, 2), Array(1, 2)),
+ Array(Array(1, 2), Array(5)),
+ Array(Array(6)))
+ val rdd = sc.parallelize(sequences, 2).cache()
+
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+
+ /*
+ To verify results, create file "prefixSpanSeqs2" with content
+ (format = (transactionID, idxInTransaction, numItemsInItemset, itemset)):
+ 1 1 2 1 2
+ 1 2 1 3
+ 2 1 1 1
+ 2 2 2 3 2
+ 2 3 2 1 2
+ 3 1 2 1 2
+ 3 2 1 5
+ 4 1 1 6
+ In R, run:
+ library("arulesSequences")
+ prefixSpanSeqs = read_baskets("prefixSpanSeqs2", info = c("sequenceID","eventID","SIZE"))
+ freqItemSeq = cspade(prefixSpanSeqs,
+ parameter = list(support = 0.5, maxlen = 5 ))
+ resSeq = as(freqItemSeq, "data.frame")
+ resSeq
+
+ sequence support
+ 1 <{1}> 0.75
+ 2 <{2}> 0.75
+ 3 <{3}> 0.50
+ 4 <{1},{3}> 0.50
+ 5 <{1,2}> 0.75
+ */
+
+ val model = prefixSpan.run(rdd)
+ val expected = Array(
+ (Array(Array(1)), 3L),
+ (Array(Array(2)), 3L),
+ (Array(Array(3)), 2L),
+ (Array(Array(1), Array(3)), 2L),
+ (Array(Array(1, 2)), 3L)
+ )
+ compareResults(expected, model.freqSequences.collect())
+ }
+
+ test("PrefixSpan String type, variable-size itemsets") {
+ // This is the same test as "PrefixSpan Integer type, variable-size itemsets" except
+ // mapped to Strings
+ val intToString = (1 to 6).zip(Seq("a", "b", "c", "d", "e", "f")).toMap
+ val sequences = Seq(
+ Array(Array(1, 2), Array(3)),
+ Array(Array(1), Array(3, 2), Array(1, 2)),
+ Array(Array(1, 2), Array(5)),
+ Array(Array(6))).map(seq => seq.map(itemSet => itemSet.map(intToString)))
+ val rdd = sc.parallelize(sequences, 2).cache()
+
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+
+ val model = prefixSpan.run(rdd)
+ val expected = Array(
+ (Array(Array(1)), 3L),
+ (Array(Array(2)), 3L),
+ (Array(Array(3)), 2L),
+ (Array(Array(1), Array(3)), 2L),
+ (Array(Array(1, 2)), 3L)
+ ).map { case (pattern, count) =>
+ (pattern.map(itemSet => itemSet.map(intToString)), count)
+ }
+ compareResults(expected, model.freqSequences.collect())
+ }
+
+ private def compareResults[Item](
+ expectedValue: Array[(Array[Array[Item]], Long)],
+ actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = {
+ val expectedSet = expectedValue.map { case (pattern: Array[Array[Item]], count: Long) =>
+ (pattern.map(itemSet => itemSet.toSet).toSeq, count)
+ }.toSet
+ val actualSet = actualValue.map { x =>
+ (x.sequence.map(_.toSet).toSeq, x.freq)
+ }.toSet
+ assert(expectedSet === actualSet)
+ }
+
+ private def compareInternalResults(
+ expectedValue: Array[(Array[Int], Long)],
+ actualValue: Array[(Array[Int], Long)]): Unit = {
+ val expectedSet = expectedValue.map(x => (x._1.toSeq, x._2)).toSet
+ val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet
+ assert(expectedSet === actualSet)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index a270ba2562db9..bfd6d5495f5e0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -74,6 +74,24 @@ class MatricesSuite extends SparkFunSuite {
}
}
+ test("equals") {
+ val dm1 = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0))
+ assert(dm1 === dm1)
+ assert(dm1 !== dm1.transpose)
+
+ val dm2 = Matrices.dense(2, 2, Array(0.0, 2.0, 1.0, 3.0))
+ assert(dm1 === dm2.transpose)
+
+ val sm1 = dm1.asInstanceOf[DenseMatrix].toSparse
+ assert(sm1 === sm1)
+ assert(sm1 === dm1)
+ assert(sm1 !== sm1.transpose)
+
+ val sm2 = dm2.asInstanceOf[DenseMatrix].toSparse
+ assert(sm1 === sm2.transpose)
+ assert(sm1 === dm2.transpose)
+ }
+
test("matrix copies are deep copies") {
val m = 3
val n = 2
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 1c37ea5123e82..6508ddeba4206 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -367,4 +367,11 @@ class VectorsSuite extends SparkFunSuite with Logging {
val sv1c = sv1.compressed.asInstanceOf[DenseVector]
assert(sv1 === sv1c)
}
+
+ test("SparseVector.slice") {
+ val v = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4))
+ assert(v.slice(Array(0, 2)) === new SparseVector(2, Array(1), Array(2.2)))
+ assert(v.slice(Array(2, 0)) === new SparseVector(2, Array(0), Array(2.2)))
+ assert(v.slice(Array(2, 0, 3, 4)) === new SparseVector(4, Array(0, 3), Array(2.2, 4.4)))
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index a2a4c5f6b8b70..34c07ed170816 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -22,14 +22,23 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
import org.apache.spark.streaming.dstream.DStream
-import org.apache.spark.streaming.TestSuiteBase
class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
// use longer wait time to ensure job completion
override def maxWaitTimeMillis: Int = 20000
+ var ssc: StreamingContext = _
+
+ override def afterFunction() {
+ super.afterFunction()
+ if (ssc != null) {
+ ssc.stop()
+ }
+ }
+
// Assert that two values are equal within tolerance epsilon
def assertEqual(v1: Double, v2: Double, epsilon: Double) {
def errorMessage = v1.toString + " did not equal " + v2.toString
@@ -62,7 +71,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
}
// apply model training to input stream
- val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
inputDStream.count()
})
@@ -98,7 +107,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
// apply model training to input stream, storing the intermediate results
// (we add a count to ensure the result is a DStream)
- val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
inputDStream.count()
@@ -129,7 +138,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
}
// apply model predictions to test stream
- val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
})
// collect the output as (true, estimated) tuples
@@ -156,7 +165,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
}
// train and predict
- val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+ ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
})
@@ -177,7 +186,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
val numBatches = 10
val nPoints = 100
val emptyInput = Seq.empty[Seq[LabeledPoint]]
- val ssc = setupStreams(emptyInput,
+ ssc = setupStreams(emptyInput,
(inputDStream: DStream[LabeledPoint]) => {
model.trainOn(inputDStream)
model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 2521b3342181a..6fc9e8df621df 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
val algos = Array(Regression, Regression, Classification)
val losses = Array(SquaredError, AbsoluteError, LogLoss)
- (algos zip losses) map {
- case (algo, loss) => {
- val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
- val gbtValidate = new GradientBoostedTrees(boostingStrategy)
- .runWithValidation(trainRdd, validateRdd)
- val numTrees = gbtValidate.numTrees
- assert(numTrees !== numIterations)
-
- // Test that it performs better on the validation dataset.
- val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
- val (errorWithoutValidation, errorWithValidation) = {
- if (algo == Classification) {
- val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
- (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
- } else {
- (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
- }
- }
- assert(errorWithValidation <= errorWithoutValidation)
-
- // Test that results from evaluateEachIteration comply with runWithValidation.
- // Note that convergenceTol is set to 0.0
- val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
- assert(evaluationArray.length === numIterations)
- assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
- var i = 1
- while (i < numTrees) {
- assert(evaluationArray(i) <= evaluationArray(i - 1))
- i += 1
+ algos.zip(losses).foreach { case (algo, loss) =>
+ val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+ val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+ .runWithValidation(trainRdd, validateRdd)
+ val numTrees = gbtValidate.numTrees
+ assert(numTrees !== numIterations)
+
+ // Test that it performs better on the validation dataset.
+ val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
+ val (errorWithoutValidation, errorWithValidation) = {
+ if (algo == Classification) {
+ val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+ (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
+ } else {
+ (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
}
}
+ assert(errorWithValidation <= errorWithoutValidation)
+
+ // Test that results from evaluateEachIteration comply with runWithValidation.
+ // Note that convergenceTol is set to 0.0
+ val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+ assert(evaluationArray.length === numIterations)
+ assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+ var i = 1
+ while (i < numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
}
}
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ sc.setCheckpointDir(path)
+
+ val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
+
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty, checkpointInterval = 2)
+ val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1)
+
+ val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+ sc.checkpointDir = None
+ Utils.deleteRecursively(tempDir)
+ }
+
}
private object GradientBoostedTreesSuite {
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 37f2e34ceb24d..e8e7f06247d3e 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -19,6 +19,7 @@
import java.io.Closeable;
import java.io.IOException;
+import java.net.SocketAddress;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
@@ -79,6 +80,10 @@ public boolean isActive() {
return channel.isOpen() || channel.isActive();
}
+ public SocketAddress getSocketAddress() {
+ return channel.remoteAddress();
+ }
+
/**
* Requests a single chunk from the remote side, from the pre-negotiated streamId.
*
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index f76bb49e874fc..f0363830b61ac 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -52,6 +52,11 @@ public static ChunkFetchFailure decode(ByteBuf buf) {
return new ChunkFetchFailure(streamChunkId, errorString);
}
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamChunkId, errorString);
+ }
+
@Override
public boolean equals(Object other) {
if (other instanceof ChunkFetchFailure) {
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
index 980947cf13f6b..5a173af54f618 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -48,6 +48,11 @@ public static ChunkFetchRequest decode(ByteBuf buf) {
return new ChunkFetchRequest(StreamChunkId.decode(buf));
}
+ @Override
+ public int hashCode() {
+ return streamChunkId.hashCode();
+ }
+
@Override
public boolean equals(Object other) {
if (other instanceof ChunkFetchRequest) {
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
index ff4936470c697..c962fb7ecf76d 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -61,6 +61,11 @@ public static ChunkFetchSuccess decode(ByteBuf buf) {
return new ChunkFetchSuccess(streamChunkId, managedBuf);
}
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamChunkId, buffer);
+ }
+
@Override
public boolean equals(Object other) {
if (other instanceof ChunkFetchSuccess) {
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
index 6b991375fc486..2dfc7876ba328 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -50,6 +50,11 @@ public static RpcFailure decode(ByteBuf buf) {
return new RpcFailure(requestId, errorString);
}
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(requestId, errorString);
+ }
+
@Override
public boolean equals(Object other) {
if (other instanceof RpcFailure) {
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
index cdee0b0e0316b..745039db742fa 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -59,6 +59,11 @@ public static RpcRequest decode(ByteBuf buf) {
return new RpcRequest(requestId, message);
}
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(requestId, Arrays.hashCode(message));
+ }
+
@Override
public boolean equals(Object other) {
if (other instanceof RpcRequest) {
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
index 0a62e09a8115c..1671cd444f039 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -50,6 +50,11 @@ public static RpcResponse decode(ByteBuf buf) {
return new RpcResponse(requestId, response);
}
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(requestId, Arrays.hashCode(response));
+ }
+
@Override
public boolean equals(Object other) {
if (other instanceof RpcResponse) {
diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
index 38113a918f795..83c90f9eff2b1 100644
--- a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
+++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
@@ -80,6 +80,11 @@ public Object convertToNetty() throws IOException {
return underlying.convertToNetty();
}
+ @Override
+ public int hashCode() {
+ return underlying.hashCode();
+ }
+
@Override
public boolean equals(Object other) {
if (other instanceof ManagedBuffer) {
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index be6632bb8cf49..8104004847a24 100644
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -17,11 +17,11 @@
package org.apache.spark.network.sasl;
-import static com.google.common.base.Charsets.UTF_8;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import java.io.File;
+import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
@@ -138,8 +138,8 @@ private void testBasicSasl(boolean encrypt) throws Exception {
public Void answer(InvocationOnMock invocation) {
byte[] message = (byte[]) invocation.getArguments()[1];
RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
- assertEquals("Ping", new String(message, UTF_8));
- cb.onSuccess("Pong".getBytes(UTF_8));
+ assertEquals("Ping", new String(message, StandardCharsets.UTF_8));
+ cb.onSuccess("Pong".getBytes(StandardCharsets.UTF_8));
return null;
}
})
@@ -148,8 +148,9 @@ public Void answer(InvocationOnMock invocation) {
SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
try {
- byte[] response = ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10));
- assertEquals("Pong", new String(response, UTF_8));
+ byte[] response = ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8),
+ TimeUnit.SECONDS.toMillis(10));
+ assertEquals("Pong", new String(response, StandardCharsets.UTF_8));
} finally {
ctx.close();
}
@@ -235,7 +236,7 @@ public void testFileRegionEncryption() throws Exception {
final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize";
System.setProperty(blockSizeConf, "1k");
- final AtomicReference response = new AtomicReference();
+ final AtomicReference response = new AtomicReference<>();
final File file = File.createTempFile("sasltest", ".txt");
SaslTestCtx ctx = null;
try {
@@ -321,7 +322,8 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception {
SaslTestCtx ctx = null;
try {
ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
- ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10));
+ ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8),
+ TimeUnit.SECONDS.toMillis(10));
fail("Should have failed to send RPC to server.");
} catch (Exception e) {
assertFalse(e.getCause() instanceof TimeoutException);
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index e4faaf8854fc7..db9dc4f17cee9 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -65,7 +65,13 @@ public ExternalShuffleBlockHandler(TransportConf conf) {
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message);
+ handleMessage(msgObj, client, callback);
+ }
+ protected void handleMessage(
+ BlockTransferMessage msgObj,
+ TransportClient client,
+ RpcResponseCallback callback) {
if (msgObj instanceof OpenBlocks) {
OpenBlocks msg = (OpenBlocks) msgObj;
List blocks = Lists.newArrayList();
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 612bce571a493..ea6d248d66be3 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -50,8 +50,8 @@ public class ExternalShuffleClient extends ShuffleClient {
private final boolean saslEncryptionEnabled;
private final SecretKeyHolder secretKeyHolder;
- private TransportClientFactory clientFactory;
- private String appId;
+ protected TransportClientFactory clientFactory;
+ protected String appId;
/**
* Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled,
@@ -71,6 +71,10 @@ public ExternalShuffleClient(
this.saslEncryptionEnabled = saslEncryptionEnabled;
}
+ protected void checkInit() {
+ assert appId != null : "Called before init()";
+ }
+
@Override
public void init(String appId) {
this.appId = appId;
@@ -89,7 +93,7 @@ public void fetchBlocks(
final String execId,
String[] blockIds,
BlockFetchingListener listener) {
- assert appId != null : "Called before init()";
+ checkInit();
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
@@ -132,7 +136,7 @@ public void registerWithShuffleServer(
int port,
String execId,
ExecutorShuffleInfo executorInfo) throws IOException {
- assert appId != null : "Called before init()";
+ checkInit();
TransportClient client = clientFactory.createClient(host, port);
byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
new file mode 100644
index 0000000000000..7543b6be4f2a1
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.mesos;
+
+import java.io.IOException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.shuffle.ExternalShuffleClient;
+import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A client for talking to the external shuffle service in Mesos coarse-grained mode.
+ *
+ * This is used by the Spark driver to register with each external shuffle service on the cluster.
+ * The reason why the driver has to talk to the service is for cleaning up shuffle files reliably
+ * after the application exits. Mesos does not provide a great alternative to do this, so Spark
+ * has to detect this itself.
+ */
+public class MesosExternalShuffleClient extends ExternalShuffleClient {
+ private final Logger logger = LoggerFactory.getLogger(MesosExternalShuffleClient.class);
+
+ /**
+ * Creates an Mesos external shuffle client that wraps the {@link ExternalShuffleClient}.
+ * Please refer to docs on {@link ExternalShuffleClient} for more information.
+ */
+ public MesosExternalShuffleClient(
+ TransportConf conf,
+ SecretKeyHolder secretKeyHolder,
+ boolean saslEnabled,
+ boolean saslEncryptionEnabled) {
+ super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled);
+ }
+
+ public void registerDriverWithShuffleService(String host, int port) throws IOException {
+ checkInit();
+ byte[] registerDriver = new RegisterDriver(appId).toByteArray();
+ TransportClient client = clientFactory.createClient(host, port);
+ client.sendRpc(registerDriver, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ logger.info("Successfully registered app " + appId + " with external shuffle service.");
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.warn("Unable to register app " + appId + " with external shuffle service. " +
+ "Please manually remove shuffle data after driver exit. Error: " + e);
+ }
+ });
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index 6c1210b33268a..fcb52363e632c 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -21,6 +21,7 @@
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver;
/**
* Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or
@@ -37,7 +38,7 @@ public abstract class BlockTransferMessage implements Encodable {
/** Preceding every serialized message is its type, which allows us to deserialize it. */
public static enum Type {
- OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3);
+ OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4);
private final byte id;
@@ -60,6 +61,7 @@ public static BlockTransferMessage fromByteArray(byte[] msg) {
case 1: return UploadBlock.decode(buf);
case 2: return RegisterExecutor.decode(buf);
case 3: return StreamHandle.decode(buf);
+ case 4: return RegisterDriver.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
new file mode 100644
index 0000000000000..94a61d6caadc4
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol.mesos;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/**
+ * A message sent from the driver to register with the MesosExternalShuffleService.
+ */
+public class RegisterDriver extends BlockTransferMessage {
+ private final String appId;
+
+ public RegisterDriver(String appId) {
+ this.appId = appId;
+ }
+
+ public String getAppId() { return appId; }
+
+ @Override
+ protected Type type() { return Type.REGISTER_DRIVER; }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId);
+ }
+
+ public static RegisterDriver decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ return new RegisterDriver(appId);
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 73374cdc77a23..1d197497b7c8f 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -90,9 +90,11 @@ public void testOpenShuffleBlocks() {
(StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue());
assertEquals(2, handle.numChunks);
- ArgumentCaptor stream = ArgumentCaptor.forClass(Iterator.class);
+ @SuppressWarnings("unchecked")
+ ArgumentCaptor> stream = (ArgumentCaptor>)
+ (ArgumentCaptor>) ArgumentCaptor.forClass(Iterator.class);
verify(streamManager, times(1)).registerStream(stream.capture());
- Iterator buffers = (Iterator) stream.getValue();
+ Iterator buffers = stream.getValue();
assertEquals(block0Marker, buffers.next());
assertEquals(block1Marker, buffers.next());
assertFalse(buffers.hasNext());
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
index 1ad0d72ae5ec5..06e46f9241094 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
@@ -20,7 +20,9 @@
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.util.Arrays;
import java.util.LinkedHashSet;
+import java.util.List;
import java.util.Map;
import com.google.common.collect.ImmutableMap;
@@ -67,13 +69,13 @@ public void afterEach() {
public void testNoFailures() throws IOException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
- Map[] interactions = new Map[] {
+ List extends Map> interactions = Arrays.asList(
// Immediately return both blocks successfully.
ImmutableMap.builder()
.put("b0", block0)
.put("b1", block1)
- .build(),
- };
+ .build()
+ );
performInteractions(interactions, listener);
@@ -86,13 +88,13 @@ public void testNoFailures() throws IOException {
public void testUnrecoverableFailure() throws IOException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
- Map[] interactions = new Map[] {
+ List extends Map> interactions = Arrays.asList(
// b0 throws a non-IOException error, so it will be failed without retry.
ImmutableMap.builder()
.put("b0", new RuntimeException("Ouch!"))
.put("b1", block1)
- .build(),
- };
+ .build()
+ );
performInteractions(interactions, listener);
@@ -105,7 +107,7 @@ public void testUnrecoverableFailure() throws IOException {
public void testSingleIOExceptionOnFirst() throws IOException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
- Map[] interactions = new Map[] {
+ List extends Map> interactions = Arrays.asList(
// IOException will cause a retry. Since b0 fails, we will retry both.
ImmutableMap.builder()
.put("b0", new IOException("Connection failed or something"))
@@ -114,8 +116,8 @@ public void testSingleIOExceptionOnFirst() throws IOException {
ImmutableMap.builder()
.put("b0", block0)
.put("b1", block1)
- .build(),
- };
+ .build()
+ );
performInteractions(interactions, listener);
@@ -128,7 +130,7 @@ public void testSingleIOExceptionOnFirst() throws IOException {
public void testSingleIOExceptionOnSecond() throws IOException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
- Map[] interactions = new Map[] {
+ List extends Map> interactions = Arrays.asList(
// IOException will cause a retry. Since b1 fails, we will not retry b0.
ImmutableMap.builder()
.put("b0", block0)
@@ -136,8 +138,8 @@ public void testSingleIOExceptionOnSecond() throws IOException {
.build(),
ImmutableMap.builder()
.put("b1", block1)
- .build(),
- };
+ .build()
+ );
performInteractions(interactions, listener);
@@ -150,7 +152,7 @@ public void testSingleIOExceptionOnSecond() throws IOException {
public void testTwoIOExceptions() throws IOException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
- Map[] interactions = new Map[] {
+ List extends Map> interactions = Arrays.asList(
// b0's IOException will trigger retry, b1's will be ignored.
ImmutableMap.builder()
.put("b0", new IOException())
@@ -164,8 +166,8 @@ public void testTwoIOExceptions() throws IOException {
// b1 returns successfully within 2 retries.
ImmutableMap.builder()
.put("b1", block1)
- .build(),
- };
+ .build()
+ );
performInteractions(interactions, listener);
@@ -178,7 +180,7 @@ public void testTwoIOExceptions() throws IOException {
public void testThreeIOExceptions() throws IOException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
- Map[] interactions = new Map[] {
+ List extends Map> interactions = Arrays.asList(
// b0's IOException will trigger retry, b1's will be ignored.
ImmutableMap.builder()
.put("b0", new IOException())
@@ -196,8 +198,8 @@ public void testThreeIOExceptions() throws IOException {
// This is not reached -- b1 has failed.
ImmutableMap.builder()
.put("b1", block1)
- .build(),
- };
+ .build()
+ );
performInteractions(interactions, listener);
@@ -210,7 +212,7 @@ public void testThreeIOExceptions() throws IOException {
public void testRetryAndUnrecoverable() throws IOException {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
- Map[] interactions = new Map[] {
+ List extends Map> interactions = Arrays.asList(
// b0's IOException will trigger retry, subsequent messages will be ignored.
ImmutableMap.builder()
.put("b0", new IOException())
@@ -226,8 +228,8 @@ public void testRetryAndUnrecoverable() throws IOException {
// b2 succeeds in its last retry.
ImmutableMap.builder()
.put("b2", block2)
- .build(),
- };
+ .build()
+ );
performInteractions(interactions, listener);
@@ -248,7 +250,8 @@ public void testRetryAndUnrecoverable() throws IOException {
* subset of the original blocks in a second interaction.
*/
@SuppressWarnings("unchecked")
- private void performInteractions(final Map[] interactions, BlockFetchingListener listener)
+ private static void performInteractions(List extends Map> interactions,
+ BlockFetchingListener listener)
throws IOException {
TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
diff --git a/pom.xml b/pom.xml
index 35fc8c44bc1b0..cfd7d32563f2a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -59,7 +59,7 @@
- 3.0.4
+ ${maven.version}
@@ -104,6 +104,7 @@
external/flume-sinkexternal/flume-assemblyexternal/mqtt
+ external/mqtt-assemblyexternal/zeromqexamplesrepl
@@ -118,6 +119,7 @@
com.typesafe.akka2.3.111.7
+ 3.3.3spark0.21.1shaded-protobuf
@@ -133,11 +135,12 @@
2.4.0org.spark-project.hive
- 0.13.1a
+ 1.2.1.spark
- 0.13.1
+ 1.2.110.10.1.11.7.0
+ 1.6.01.2.48.1.14.v201310313.0.0.v201112011016
@@ -150,7 +153,10 @@
0.7.11.9.161.2.1
+
4.3.2
+
+ 3.13.4.12.10.42.10
@@ -160,8 +166,19 @@
2.4.41.1.1.71.1.2
-
- false
+ 1.2.0-incubating
+ 1.10
+
+ 2.6
+
+ 3.3.2
+ 3.2.10
+ 2.7.8
+ 1.9
+ 2.5
+ 3.5.2
+ 1.3.9
+ 0.9.2${java.home}
@@ -190,7 +207,6 @@
512m512m
-
central
@@ -248,6 +264,14 @@
false
+
+ spark-hive-staging
+ Staging Repo for Hive 1.2.1 (Spark Version)
+ https://oss.sonatype.org/content/repositories/orgspark-project-1113
+
+ true
+
+ mapr-repoMapR Repository
@@ -259,12 +283,13 @@
false
+
spring-releasesSpring Release Repositoryhttps://repo.spring.io/libs-release
- true
+ falsefalse
@@ -404,12 +429,17 @@
org.apache.commonscommons-lang3
- 3.3.2
+ ${commons-lang3.version}
+
+
+ org.apache.commons
+ commons-lang
+ ${commons-lang2.version}commons-codeccommons-codec
- 1.10
+ ${commons-codec.version}org.apache.commons
@@ -424,7 +454,12 @@
com.google.code.findbugsjsr305
- 1.3.9
+ ${jsr305.version}
+
+
+ commons-httpclient
+ commons-httpclient
+ ${httpclient.classic.version}org.apache.httpcomponents
@@ -441,6 +476,16 @@
selenium-java2.42.2test
+
+
+ com.google.guava
+ guava
+
+
+ io.netty
+ netty
+
+
@@ -626,15 +671,26 @@
com.sun.jerseyjersey-server
- 1.9
+ ${jersey.version}${hadoop.deps.scope}com.sun.jerseyjersey-core
- 1.9
+ ${jersey.version}${hadoop.deps.scope}
+
+ com.sun.jersey
+ jersey-json
+ ${jersey.version}
+
+
+ stax
+ stax-api
+
+
+ org.scala-langscala-compiler
@@ -1024,58 +1080,499 @@
hive-beeline${hive.version}${hive.deps.scope}
+
+
+ ${hive.group}
+ hive-common
+
+
+ ${hive.group}
+ hive-exec
+
+
+ ${hive.group}
+ hive-jdbc
+
+
+ ${hive.group}
+ hive-metastore
+
+
+ ${hive.group}
+ hive-service
+
+
+ ${hive.group}
+ hive-shims
+
+
+ org.apache.thrift
+ libthrift
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ log4j
+ log4j
+
+
+ commons-logging
+ commons-logging
+
+ ${hive.group}hive-cli${hive.version}${hive.deps.scope}
+
+
+ ${hive.group}
+ hive-common
+
+
+ ${hive.group}
+ hive-exec
+
+
+ ${hive.group}
+ hive-jdbc
+
+
+ ${hive.group}
+ hive-metastore
+
+
+ ${hive.group}
+ hive-serde
+
+
+ ${hive.group}
+ hive-service
+
+
+ ${hive.group}
+ hive-shims
+
+
+ org.apache.thrift
+ libthrift
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ log4j
+ log4j
+
+
+ commons-logging
+ commons-logging
+
+ ${hive.group}
- hive-exec
+ hive-common${hive.version}${hive.deps.scope}
+
+ ${hive.group}
+ hive-shims
+
+
+ org.apache.ant
+ ant
+
+
+ org.apache.zookeeper
+ zookeeper
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ log4j
+ log4j
+ commons-loggingcommons-logging
+
+
+
+
+ ${hive.group}
+ hive-exec
+
+ ${hive.version}
+ ${hive.deps.scope}
+
+
+
+
+ ${hive.group}
+ hive-metastore
+
+
+ ${hive.group}
+ hive-shims
+
+
+ ${hive.group}
+ hive-ant
+
+
+
+ ${hive.group}
+ spark-client
+
+
+
+
+ ant
+ ant
+
+
+ org.apache.ant
+ ant
+ com.esotericsoftware.kryokryo
+
+ commons-codec
+ commons-codec
+
+
+ commons-httpclient
+ commons-httpclient
+ org.apache.avroavro-mapred
+
+
+ org.apache.calcite
+ calcite-core
+
+
+ org.apache.curator
+ apache-curator
+
+
+ org.apache.curator
+ curator-client
+
+
+ org.apache.curator
+ curator-framework
+
+
+ org.apache.thrift
+ libthrift
+
+
+ org.apache.thrift
+ libfb303
+
+
+ org.apache.zookeeper
+ zookeeper
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ log4j
+ log4j
+
+
+ commons-logging
+ commons-logging
+ ${hive.group}hive-jdbc${hive.version}
- ${hive.deps.scope}
+
+
+ ${hive.group}
+ hive-common
+
+
+ ${hive.group}
+ hive-common
+
+
+ ${hive.group}
+ hive-metastore
+
+
+ ${hive.group}
+ hive-serde
+
+
+ ${hive.group}
+ hive-service
+
+
+ ${hive.group}
+ hive-shims
+
+
+ org.apache.httpcomponents
+ httpclient
+
+
+ org.apache.httpcomponents
+ httpcore
+
+
+ org.apache.curator
+ curator-framework
+
+
+ org.apache.thrift
+ libthrift
+
+
+ org.apache.thrift
+ libfb303
+
+
+ org.apache.zookeeper
+ zookeeper
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ log4j
+ log4j
+
+
+ commons-logging
+ commons-logging
+
+
+
${hive.group}hive-metastore${hive.version}${hive.deps.scope}
+
+
+ ${hive.group}
+ hive-serde
+
+
+ ${hive.group}
+ hive-shims
+
+
+ org.apache.thrift
+ libfb303
+
+
+ org.apache.thrift
+ libthrift
+
+
+ com.google.guava
+ guava
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+
${hive.group}hive-serde${hive.version}${hive.deps.scope}
+
+ ${hive.group}
+ hive-common
+
+
+ ${hive.group}
+ hive-shims
+
+
+ commons-codec
+ commons-codec
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+ org.apache.avro
+ avro
+
+
+ org.apache.thrift
+ libthrift
+
+
+ org.apache.thrift
+ libfb303
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ log4j
+ log4j
+ commons-loggingcommons-logging
+
+
+
+
+ ${hive.group}
+ hive-service
+ ${hive.version}
+ ${hive.deps.scope}
+
+
+ ${hive.group}
+ hive-common
+
+
+ ${hive.group}
+ hive-exec
+
+
+ ${hive.group}
+ hive-metastore
+
+
+ ${hive.group}
+ hive-shims
+
+
+ commons-codec
+ commons-codec
+
+
+ org.apache.curator
+ curator-framework
+
+
+ org.apache.curator
+ curator-recipes
+
+
+ org.apache.thrift
+ libfb303
+
+
+ org.apache.thrift
+ libthrift
+
+
+
+
+
+
+ ${hive.group}
+ hive-shims
+ ${hive.version}
+ ${hive.deps.scope}
+
+
+ com.google.guava
+ guava
+
+
+ org.apache.hadoop
+ hadoop-yarn-server-resourcemanager
+
+
+ org.apache.curator
+ curator-framework
+
+
+ org.apache.thrift
+ libthrift
+
+
+ org.apache.zookeeper
+ zookeeper
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ log4j
+ log4j
+ commons-logging
- commons-logging-api
+ commons-logging
@@ -1097,6 +1594,12 @@
${parquet.version}${parquet.test.deps.scope}
+
+ com.twitter
+ parquet-hadoop-bundle
+ ${hive.parquet.version}
+ runtime
+ org.apache.flumeflume-ng-core
@@ -1137,6 +1640,125 @@
+
+ org.apache.calcite
+ calcite-core
+ ${calcite.version}
+
+
+ com.fasterxml.jackson.core
+ jackson-annotations
+
+
+ com.fasterxml.jackson.core
+ jackson-core
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
+
+ com.google.guava
+ guava
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+ org.codehaus.janino
+ janino
+
+
+
+ org.hsqldb
+ hsqldb
+
+
+ org.pentaho
+ pentaho-aggdesigner-algorithm
+
+
+
+
+ org.apache.calcite
+ calcite-avatica
+ ${calcite.version}
+
+
+ com.fasterxml.jackson.core
+ jackson-annotations
+
+
+ com.fasterxml.jackson.core
+ jackson-core
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
+
+
+
+ org.codehaus.janino
+ janino
+ ${janino.version}
+
+
+ joda-time
+ joda-time
+ ${joda.version}
+
+
+ org.jodd
+ jodd-core
+ ${jodd.version}
+
+
+ org.datanucleus
+ datanucleus-core
+ ${datanucleus-core.version}
+
+
+ org.apache.thrift
+ libthrift
+ ${libthrift.version}
+
+
+ org.apache.httpcomponents
+ httpclient
+
+
+ org.apache.httpcomponents
+ httpcore
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
+ org.apache.thrift
+ libfb303
+ ${libthrift.version}
+
+
+ org.apache.httpcomponents
+ httpclient
+
+
+ org.apache.httpcomponents
+ httpcore
+
+
+ org.slf4j
+ slf4j-api
+
+
+
@@ -1156,7 +1778,7 @@
- 3.0.4
+ ${maven.version}${java.version}
@@ -1174,7 +1796,7 @@
net.alchim31.mavenscala-maven-plugin
- 3.2.0
+ 3.2.2eclipse-add-source
@@ -1225,6 +1847,7 @@
${java.version}-target${java.version}
+ -Xlint:all,-serial,-path
@@ -1238,6 +1861,9 @@
UTF-81024mtrue
+
+ -Xlint:all,-serial,-path
+
@@ -1269,10 +1895,13 @@
${project.build.directory}/tmp${spark.test.home}1
+ falsefalsefalsetruetrue
+
+ srcfalse
@@ -1307,6 +1936,8 @@
falsetruetrue
+
+ __not_used__
@@ -1376,7 +2007,12 @@
org.apache.maven.pluginsmaven-assembly-plugin
- 2.5.3
+ 2.5.5
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+ 2.4.1org.apache.maven.plugins
@@ -1470,11 +2106,8 @@
org.apache.maven.pluginsmaven-shade-plugin
- 2.3false
-
- ${create.dependency.reduced.pom}
@@ -1642,6 +2275,7 @@
kinesis-aslextras/kinesis-asl
+ extras/kinesis-asl-assembly
@@ -1835,26 +2469,6 @@
-
-
- release
-
-
- true
-
-
-
| Partitioning |
+ * | +--------------+ +--------------+
+ * | ^
+ * | |
+ * | compatibleWith
+ * | |
+ * +------------+
+ *
+ */
sealed trait Partitioning {
/** Returns the number of partitions that the data is split across */
val numPartitions: Int
@@ -87,15 +119,68 @@ sealed trait Partitioning {
def satisfies(required: Distribution): Boolean
/**
- * Returns true iff all distribution guarantees made by this partitioning can also be made
- * for the `other` specified partitioning.
- * For example, two [[HashPartitioning HashPartitioning]]s are
- * only compatible if the `numPartitions` of them is the same.
+ * Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
+ * guarantees the same partitioning scheme described by `other`.
+ *
+ * Compatibility of partitionings is only checked for operators that have multiple children
+ * and that require a specific child output [[Distribution]], such as joins.
+ *
+ * Intuitively, partitionings are compatible if they route the same partitioning key to the same
+ * partition. For instance, two hash partitionings are only compatible if they produce the same
+ * number of output partitionings and hash records according to the same hash function and
+ * same partitioning key schema.
+ *
+ * Put another way, two partitionings are compatible with each other if they satisfy all of the
+ * same distribution guarantees.
*/
def compatibleWith(other: Partitioning): Boolean
- /** Returns the expressions that are used to key the partitioning. */
- def keyExpressions: Seq[Expression]
+ /**
+ * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] guarantees
+ * the same partitioning scheme described by `other`. If a `A.guarantees(B)`, then repartitioning
+ * the child's output according to `B` will be unnecessary. `guarantees` is used as a performance
+ * optimization to allow the exchange planner to avoid redundant repartitionings. By default,
+ * a partitioning only guarantees partitionings that are equal to itself (i.e. the same number
+ * of partitions, same strategy (range or hash), etc).
+ *
+ * In order to enable more aggressive optimization, this strict equality check can be relaxed.
+ * For example, say that the planner needs to repartition all of an operator's children so that
+ * they satisfy the [[AllTuples]] distribution. One way to do this is to repartition all children
+ * to have the [[SinglePartition]] partitioning. If one of the operator's children already happens
+ * to be hash-partitioned with a single partition then we do not need to re-shuffle this child;
+ * this repartitioning can be avoided if a single-partition [[HashPartitioning]] `guarantees`
+ * [[SinglePartition]].
+ *
+ * The SinglePartition example given above is not particularly interesting; guarantees' real
+ * value occurs for more advanced partitioning strategies. SPARK-7871 will introduce a notion
+ * of null-safe partitionings, under which partitionings can specify whether rows whose
+ * partitioning keys contain null values will be grouped into the same partition or whether they
+ * will have an unknown / random distribution. If a partitioning does not require nulls to be
+ * clustered then a partitioning which _does_ cluster nulls will guarantee the null clustered
+ * partitioning. The converse is not true, however: a partitioning which clusters nulls cannot
+ * be guaranteed by one which does not cluster them. Thus, in general `guarantees` is not a
+ * symmetric relation.
+ *
+ * Another way to think about `guarantees`: if `A.guarantees(B)`, then any partitioning of rows
+ * produced by `A` could have also been produced by `B`.
+ */
+ def guarantees(other: Partitioning): Boolean = this == other
+}
+
+object Partitioning {
+ def allCompatible(partitionings: Seq[Partitioning]): Boolean = {
+ // Note: this assumes transitivity
+ partitionings.sliding(2).map {
+ case Seq(a) => true
+ case Seq(a, b) =>
+ if (a.numPartitions != b.numPartitions) {
+ assert(!a.compatibleWith(b) && !b.compatibleWith(a))
+ false
+ } else {
+ a.compatibleWith(b) && b.compatibleWith(a)
+ }
+ }.forall(_ == true)
+ }
}
case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -104,12 +189,9 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case _ => false
}
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case UnknownPartitioning(_) => true
- case _ => false
- }
+ override def compatibleWith(other: Partitioning): Boolean = false
- override def keyExpressions: Seq[Expression] = Nil
+ override def guarantees(other: Partitioning): Boolean = false
}
case object SinglePartition extends Partitioning {
@@ -117,25 +199,9 @@ case object SinglePartition extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case SinglePartition => true
- case _ => false
- }
-
- override def keyExpressions: Seq[Expression] = Nil
-}
-
-case object BroadcastPartitioning extends Partitioning {
- val numPartitions = 1
-
- override def satisfies(required: Distribution): Boolean = true
-
- override def compatibleWith(other: Partitioning): Boolean = other match {
- case SinglePartition => true
- case _ => false
- }
+ override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1
- override def keyExpressions: Seq[Expression] = Nil
+ override def guarantees(other: Partitioning): Boolean = other.numPartitions == 1
}
/**
@@ -150,22 +216,23 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def nullable: Boolean = false
override def dataType: DataType = IntegerType
- private[this] lazy val clusteringSet = expressions.toSet
-
override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case ClusteredDistribution(requiredClustering) =>
- clusteringSet.subsetOf(requiredClustering.toSet)
+ expressions.toSet.subsetOf(requiredClustering.toSet)
case _ => false
}
override def compatibleWith(other: Partitioning): Boolean = other match {
- case BroadcastPartitioning => true
- case h: HashPartitioning if h == this => true
+ case o: HashPartitioning => this == o
+ case _ => false
+ }
+
+ override def guarantees(other: Partitioning): Boolean = other match {
+ case o: HashPartitioning => this == o
case _ => false
}
- override def keyExpressions: Seq[Expression] = expressions
}
/**
@@ -187,23 +254,79 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
override def nullable: Boolean = false
override def dataType: DataType = IntegerType
- private[this] lazy val clusteringSet = ordering.map(_.child).toSet
-
override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case OrderedDistribution(requiredOrdering) =>
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case ClusteredDistribution(requiredClustering) =>
- clusteringSet.subsetOf(requiredClustering.toSet)
+ ordering.map(_.child).toSet.subsetOf(requiredClustering.toSet)
case _ => false
}
override def compatibleWith(other: Partitioning): Boolean = other match {
- case BroadcastPartitioning => true
- case r: RangePartitioning if r == this => true
+ case o: RangePartitioning => this == o
+ case _ => false
+ }
+
+ override def guarantees(other: Partitioning): Boolean = other match {
+ case o: RangePartitioning => this == o
case _ => false
}
+}
- override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+/**
+ * A collection of [[Partitioning]]s that can be used to describe the partitioning
+ * scheme of the output of a physical operator. It is usually used for an operator
+ * that has multiple children. In this case, a [[Partitioning]] in this collection
+ * describes how this operator's output is partitioned based on expressions from
+ * a child. For example, for a Join operator on two tables `A` and `B`
+ * with a join condition `A.key1 = B.key2`, assuming we use HashPartitioning schema,
+ * there are two [[Partitioning]]s that can be used to describe how the output of
+ * this Join operator is partitioned, which are `HashPartitioning(A.key1)` and
+ * `HashPartitioning(B.key2)`. It is also worth noting that `partitionings`
+ * in this collection do not need to be equivalent, which is useful for
+ * Outer Join operators.
+ */
+case class PartitioningCollection(partitionings: Seq[Partitioning])
+ extends Expression with Partitioning with Unevaluable {
+
+ require(
+ partitionings.map(_.numPartitions).distinct.length == 1,
+ s"PartitioningCollection requires all of its partitionings have the same numPartitions.")
+
+ override def children: Seq[Expression] = partitionings.collect {
+ case expr: Expression => expr
+ }
+
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = IntegerType
+
+ override val numPartitions = partitionings.map(_.numPartitions).distinct.head
+
+ /**
+ * Returns true if any `partitioning` of this collection satisfies the given
+ * [[Distribution]].
+ */
+ override def satisfies(required: Distribution): Boolean =
+ partitionings.exists(_.satisfies(required))
+
+ /**
+ * Returns true if any `partitioning` of this collection is compatible with
+ * the given [[Partitioning]].
+ */
+ override def compatibleWith(other: Partitioning): Boolean =
+ partitionings.exists(_.compatibleWith(other))
+
+ /**
+ * Returns true if any `partitioning` of this collection guarantees
+ * the given [[Partitioning]].
+ */
+ override def guarantees(other: Partitioning): Boolean =
+ partitionings.exists(_.guarantees(other))
+
+ override def toString: String = {
+ partitionings.map(_.toString).mkString("(", " or ", ")")
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 3f9858b0c4a43..f80d2a93241d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -17,10 +17,30 @@
package org.apache.spark.sql.catalyst.rules
+import scala.collection.JavaConverters._
+
+import com.google.common.util.concurrent.AtomicLongMap
+
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.sideBySide
+object RuleExecutor {
+ protected val timeMap = AtomicLongMap.create[String]()
+
+ /** Resets statistics about time spent running specific rules */
+ def resetTime(): Unit = timeMap.clear()
+
+ /** Dump statistics about time spent running specific rules. */
+ def dumpTimeSpent(): String = {
+ val map = timeMap.asMap().asScala
+ val maxSize = map.keys.map(_.toString.length).max
+ map.toSeq.sortBy(_._2).reverseMap { case (k, v) =>
+ s"${k.padTo(maxSize, " ").mkString} $v"
+ }.mkString("\n")
+ }
+}
+
abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
/**
@@ -41,6 +61,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
/** Defines a sequence of rule batches, to be overridden by the implementation. */
protected val batches: Seq[Batch]
+
/**
* Executes the batches of rules defined by the subclass. The batches are executed serially
* using the defined execution strategy. Within each batch, rules are also executed serially.
@@ -58,7 +79,11 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
while (continue) {
curPlan = batch.rules.foldLeft(curPlan) {
case (plan, rule) =>
+ val startTime = System.nanoTime()
val result = rule(plan)
+ val runTime = System.nanoTime() - startTime
+ RuleExecutor.timeMap.addAndGet(rule.ruleName, runTime)
+
if (!result.fastEquals(plan)) {
logTrace(
s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 122e9fc5ed77f..7971e25188e8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -149,7 +149,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
/**
* Returns a copy of this node where `f` has been applied to all the nodes children.
*/
- def mapChildren(f: BaseType => BaseType): this.type = {
+ def mapChildren(f: BaseType => BaseType): BaseType = {
var changed = false
val newArgs = productIterator.map {
case arg: TreeNode[_] if containsChild(arg) =>
@@ -170,7 +170,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* Returns a copy of this node with the children replaced.
* TODO: Validate somewhere (in debug mode?) that children are ordered correctly.
*/
- def withNewChildren(newChildren: Seq[BaseType]): this.type = {
+ def withNewChildren(newChildren: Seq[BaseType]): BaseType = {
assert(newChildren.size == children.size, "Incorrect number of children")
var changed = false
val remainingNewChildren = newChildren.toBuffer
@@ -229,9 +229,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
// Check if unchanged and then possibly return old copy to avoid gc churn.
if (this fastEquals afterRule) {
- transformChildrenDown(rule)
+ transformChildren(rule, (t, r) => t.transformDown(r))
} else {
- afterRule.transformChildrenDown(rule)
+ afterRule.transformChildren(rule, (t, r) => t.transformDown(r))
}
}
@@ -240,11 +240,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* this node. When `rule` does not apply to a given node it is left unchanged.
* @param rule the function used to transform this nodes children
*/
- def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = {
+ protected def transformChildren(
+ rule: PartialFunction[BaseType, BaseType],
+ nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = {
var changed = false
val newArgs = productIterator.map {
case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
+ val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
if (!(newChild fastEquals arg)) {
changed = true
newChild
@@ -252,7 +254,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
arg
}
case Some(arg: TreeNode[_]) if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
+ val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
if (!(newChild fastEquals arg)) {
changed = true
Some(newChild)
@@ -263,7 +265,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
+ val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
if (!(newChild fastEquals arg)) {
changed = true
newChild
@@ -285,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* @param rule the function use to transform this nodes children
*/
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
- val afterRuleOnChildren = transformChildrenUp(rule)
+ val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r))
if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
@@ -297,44 +299,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
}
- def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = {
- var changed = false
- val newArgs = productIterator.map {
- case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
- if (!(newChild fastEquals arg)) {
- changed = true
- newChild
- } else {
- arg
- }
- case Some(arg: TreeNode[_]) if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
- if (!(newChild fastEquals arg)) {
- changed = true
- Some(newChild)
- } else {
- Some(arg)
- }
- case m: Map[_, _] => m
- case d: DataType => d // Avoid unpacking Structs
- case args: Traversable[_] => args.map {
- case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
- if (!(newChild fastEquals arg)) {
- changed = true
- newChild
- } else {
- arg
- }
- case other => other
- }
- case nonChild: AnyRef => nonChild
- case null => null
- }.toArray
- if (changed) makeCopy(newArgs) else this
- }
-
/**
* Args to the constructor that should be copied, but not transformed.
* These are appended to the transformed args automatically by makeCopy
@@ -348,7 +312,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* that are not present in the productIterator.
* @param newArgs the new product arguments.
*/
- def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
+ def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") {
val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0)
if (ctors.isEmpty) {
sys.error(s"No valid constructor for $nodeName")
@@ -359,9 +323,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
CurrentOrigin.withOrigin(origin) {
// Skip no-arg constructors that are just there for kryo.
if (otherCopyArgs.isEmpty) {
- defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
+ defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType]
} else {
- defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type]
+ defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType]
}
}
} catch {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 53abdf6618eac..672620460c3c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -31,6 +31,11 @@ import org.apache.spark.unsafe.types.UTF8String
* precision.
*/
object DateTimeUtils {
+
+ // we use Int and Long internally to represent [[DateType]] and [[TimestampType]]
+ type SQLDate = Int
+ type SQLTimestamp = Long
+
// see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian
final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5
final val SECONDS_PER_DAY = 60 * 60 * 24L
@@ -72,7 +77,7 @@ object DateTimeUtils {
}
// we should use the exact day as Int, for example, (year, month, day) -> day
- def millisToDays(millisUtc: Long): Int = {
+ def millisToDays(millisUtc: Long): SQLDate = {
// SPARK-6785: use Math.floor so negative number of days (dates before 1970)
// will correctly work as input for function toJavaDate(Int)
val millisLocal = millisUtc + threadLocalLocalTimeZone.get().getOffset(millisUtc)
@@ -80,16 +85,16 @@ object DateTimeUtils {
}
// reverse of millisToDays
- def daysToMillis(days: Int): Long = {
+ def daysToMillis(days: SQLDate): Long = {
val millisUtc = days.toLong * MILLIS_PER_DAY
millisUtc - threadLocalLocalTimeZone.get().getOffset(millisUtc)
}
- def dateToString(days: Int): String =
+ def dateToString(days: SQLDate): String =
threadLocalDateFormat.get.format(toJavaDate(days))
// Converts Timestamp to string according to Hive TimestampWritable convention.
- def timestampToString(us: Long): String = {
+ def timestampToString(us: SQLTimestamp): String = {
val ts = toJavaTimestamp(us)
val timestampString = ts.toString
val formatted = threadLocalTimestampFormat.get.format(ts)
@@ -132,21 +137,21 @@ object DateTimeUtils {
/**
* Returns the number of days since epoch from from java.sql.Date.
*/
- def fromJavaDate(date: Date): Int = {
+ def fromJavaDate(date: Date): SQLDate = {
millisToDays(date.getTime)
}
/**
* Returns a java.sql.Date from number of days since epoch.
*/
- def toJavaDate(daysSinceEpoch: Int): Date = {
+ def toJavaDate(daysSinceEpoch: SQLDate): Date = {
new Date(daysToMillis(daysSinceEpoch))
}
/**
* Returns a java.sql.Timestamp from number of micros since epoch.
*/
- def toJavaTimestamp(us: Long): Timestamp = {
+ def toJavaTimestamp(us: SQLTimestamp): Timestamp = {
// setNanos() will overwrite the millisecond part, so the milliseconds should be
// cut off at seconds
var seconds = us / MICROS_PER_SECOND
@@ -164,7 +169,7 @@ object DateTimeUtils {
/**
* Returns the number of micros since epoch from java.sql.Timestamp.
*/
- def fromJavaTimestamp(t: Timestamp): Long = {
+ def fromJavaTimestamp(t: Timestamp): SQLTimestamp = {
if (t != null) {
t.getTime() * 1000L + (t.getNanos().toLong / 1000) % 1000L
} else {
@@ -176,7 +181,7 @@ object DateTimeUtils {
* Returns the number of microseconds since epoch from Julian day
* and nanoseconds in a day
*/
- def fromJulianDay(day: Int, nanoseconds: Long): Long = {
+ def fromJulianDay(day: Int, nanoseconds: Long): SQLTimestamp = {
// use Long to avoid rounding errors
val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY - SECONDS_PER_DAY / 2
seconds * MICROS_PER_SECOND + nanoseconds / 1000L
@@ -185,7 +190,7 @@ object DateTimeUtils {
/**
* Returns Julian day and nanoseconds in a day from the number of microseconds
*/
- def toJulianDay(us: Long): (Int, Long) = {
+ def toJulianDay(us: SQLTimestamp): (Int, Long) = {
val seconds = us / MICROS_PER_SECOND + SECONDS_PER_DAY / 2
val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH
val secondsInDay = seconds % SECONDS_PER_DAY
@@ -219,7 +224,7 @@ object DateTimeUtils {
* `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]-[h]h:[m]m`
* `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m`
*/
- def stringToTimestamp(s: UTF8String): Option[Long] = {
+ def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = {
if (s == null) {
return None
}
@@ -355,7 +360,7 @@ object DateTimeUtils {
* `yyyy-[m]m-[d]d *`
* `yyyy-[m]m-[d]dT*`
*/
- def stringToDate(s: UTF8String): Option[Int] = {
+ def stringToDate(s: UTF8String): Option[SQLDate] = {
if (s == null) {
return None
}
@@ -394,7 +399,7 @@ object DateTimeUtils {
/**
* Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds.
*/
- def getHours(timestamp: Long): Int = {
+ def getHours(timestamp: SQLTimestamp): Int = {
val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000)
((localTs / 1000 / 3600) % 24).toInt
}
@@ -403,7 +408,7 @@ object DateTimeUtils {
* Returns the minute value of a given timestamp value. The timestamp is expressed in
* microseconds.
*/
- def getMinutes(timestamp: Long): Int = {
+ def getMinutes(timestamp: SQLTimestamp): Int = {
val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000)
((localTs / 1000 / 60) % 60).toInt
}
@@ -412,7 +417,7 @@ object DateTimeUtils {
* Returns the second value of a given timestamp value. The timestamp is expressed in
* microseconds.
*/
- def getSeconds(timestamp: Long): Int = {
+ def getSeconds(timestamp: SQLTimestamp): Int = {
((timestamp / 1000 / 1000) % 60).toInt
}
@@ -447,7 +452,7 @@ object DateTimeUtils {
* The calculation uses the fact that the period 1.1.2001 until 31.12.2400 is
* equals to the period 1.1.1601 until 31.12.2000.
*/
- private[this] def getYearAndDayInYear(daysSince1970: Int): (Int, Int) = {
+ private[this] def getYearAndDayInYear(daysSince1970: SQLDate): (Int, Int) = {
// add the difference (in days) between 1.1.1970 and the artificial year 0 (-17999)
val daysNormalized = daysSince1970 + toYearZero
val numOfQuarterCenturies = daysNormalized / daysIn400Years
@@ -461,7 +466,7 @@ object DateTimeUtils {
* Returns the 'day in year' value for the given date. The date is expressed in days
* since 1.1.1970.
*/
- def getDayInYear(date: Int): Int = {
+ def getDayInYear(date: SQLDate): Int = {
getYearAndDayInYear(date)._2
}
@@ -469,7 +474,7 @@ object DateTimeUtils {
* Returns the year value for the given date. The date is expressed in days
* since 1.1.1970.
*/
- def getYear(date: Int): Int = {
+ def getYear(date: SQLDate): Int = {
getYearAndDayInYear(date)._1
}
@@ -477,7 +482,7 @@ object DateTimeUtils {
* Returns the quarter for the given date. The date is expressed in days
* since 1.1.1970.
*/
- def getQuarter(date: Int): Int = {
+ def getQuarter(date: SQLDate): Int = {
var (year, dayInYear) = getYearAndDayInYear(date)
if (isLeapYear(year)) {
dayInYear = dayInYear - 1
@@ -493,11 +498,55 @@ object DateTimeUtils {
}
}
+ /**
+ * Split date (expressed in days since 1.1.1970) into four fields:
+ * year, month (Jan is Month 1), dayInMonth, daysToMonthEnd (0 if it's last day of month).
+ */
+ def splitDate(date: SQLDate): (Int, Int, Int, Int) = {
+ var (year, dayInYear) = getYearAndDayInYear(date)
+ val isLeap = isLeapYear(year)
+ if (isLeap && dayInYear == 60) {
+ (year, 2, 29, 0)
+ } else {
+ if (isLeap && dayInYear > 60) dayInYear -= 1
+
+ if (dayInYear <= 181) {
+ if (dayInYear <= 31) {
+ (year, 1, dayInYear, 31 - dayInYear)
+ } else if (dayInYear <= 59) {
+ (year, 2, dayInYear - 31, if (isLeap) 60 - dayInYear else 59 - dayInYear)
+ } else if (dayInYear <= 90) {
+ (year, 3, dayInYear - 59, 90 - dayInYear)
+ } else if (dayInYear <= 120) {
+ (year, 4, dayInYear - 90, 120 - dayInYear)
+ } else if (dayInYear <= 151) {
+ (year, 5, dayInYear - 120, 151 - dayInYear)
+ } else {
+ (year, 6, dayInYear - 151, 181 - dayInYear)
+ }
+ } else {
+ if (dayInYear <= 212) {
+ (year, 7, dayInYear - 181, 212 - dayInYear)
+ } else if (dayInYear <= 243) {
+ (year, 8, dayInYear - 212, 243 - dayInYear)
+ } else if (dayInYear <= 273) {
+ (year, 9, dayInYear - 243, 273 - dayInYear)
+ } else if (dayInYear <= 304) {
+ (year, 10, dayInYear - 273, 304 - dayInYear)
+ } else if (dayInYear <= 334) {
+ (year, 11, dayInYear - 304, 334 - dayInYear)
+ } else {
+ (year, 12, dayInYear - 334, 365 - dayInYear)
+ }
+ }
+ }
+ }
+
/**
* Returns the month value for the given date. The date is expressed in days
* since 1.1.1970. January is month 1.
*/
- def getMonth(date: Int): Int = {
+ def getMonth(date: SQLDate): Int = {
var (year, dayInYear) = getYearAndDayInYear(date)
if (isLeapYear(year)) {
if (dayInYear == 60) {
@@ -538,7 +587,7 @@ object DateTimeUtils {
* Returns the 'day of month' value for the given date. The date is expressed in days
* since 1.1.1970.
*/
- def getDayOfMonth(date: Int): Int = {
+ def getDayOfMonth(date: SQLDate): Int = {
var (year, dayInYear) = getYearAndDayInYear(date)
if (isLeapYear(year)) {
if (dayInYear == 60) {
@@ -584,7 +633,7 @@ object DateTimeUtils {
* Returns the date value for the first day of the given month.
* The month is expressed in months since year zero (17999 BC), starting from 0.
*/
- private def firstDayOfMonth(absoluteMonth: Int): Int = {
+ private def firstDayOfMonth(absoluteMonth: Int): SQLDate = {
val absoluteYear = absoluteMonth / 12
var monthInYear = absoluteMonth - absoluteYear * 12
var date = getDateFromYear(absoluteYear)
@@ -602,7 +651,7 @@ object DateTimeUtils {
* Returns the date value for January 1 of the given year.
* The year is expressed in years since year zero (17999 BC), starting from 0.
*/
- private def getDateFromYear(absoluteYear: Int): Int = {
+ private def getDateFromYear(absoluteYear: Int): SQLDate = {
val absoluteDays = (absoluteYear * 365 + absoluteYear / 400 - absoluteYear / 100
+ absoluteYear / 4)
absoluteDays - toYearZero
@@ -612,73 +661,35 @@ object DateTimeUtils {
* Add date and year-month interval.
* Returns a date value, expressed in days since 1.1.1970.
*/
- def dateAddMonths(days: Int, months: Int): Int = {
- val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months
- val currentMonthInYear = absoluteMonth % 12
- val currentYear = absoluteMonth / 12
+ def dateAddMonths(days: SQLDate, months: Int): SQLDate = {
+ val (year, monthInYear, dayOfMonth, daysToMonthEnd) = splitDate(days)
+ val absoluteMonth = (year - YearZero) * 12 + monthInYear - 1 + months
+ val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0
+ val currentMonthInYear = nonNegativeMonth % 12
+ val currentYear = nonNegativeMonth / 12
+
val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0
val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay
- val dayOfMonth = getDayOfMonth(days)
- val currentDayInMonth = if (getDayOfMonth(days + 1) == 1 || dayOfMonth >= lastDayOfMonth) {
+ val currentDayInMonth = if (daysToMonthEnd == 0 || dayOfMonth >= lastDayOfMonth) {
// last day of the month
lastDayOfMonth
} else {
dayOfMonth
}
- firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1
+ firstDayOfMonth(nonNegativeMonth) + currentDayInMonth - 1
}
/**
* Add timestamp and full interval.
* Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00.
*/
- def timestampAddInterval(start: Long, months: Int, microseconds: Long): Long = {
+ def timestampAddInterval(start: SQLTimestamp, months: Int, microseconds: Long): SQLTimestamp = {
val days = millisToDays(start / 1000L)
val newDays = dateAddMonths(days, months)
daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds
}
- /**
- * Returns the last dayInMonth in the month it belongs to. The date is expressed
- * in days since 1.1.1970. the return value starts from 1.
- */
- private def getLastDayInMonthOfMonth(date: Int): Int = {
- var (year, dayInYear) = getYearAndDayInYear(date)
- if (isLeapYear(year)) {
- if (dayInYear > 31 && dayInYear <= 60) {
- return 29
- } else if (dayInYear > 60) {
- dayInYear = dayInYear - 1
- }
- }
- if (dayInYear <= 31) {
- 31
- } else if (dayInYear <= 59) {
- 28
- } else if (dayInYear <= 90) {
- 31
- } else if (dayInYear <= 120) {
- 30
- } else if (dayInYear <= 151) {
- 31
- } else if (dayInYear <= 181) {
- 30
- } else if (dayInYear <= 212) {
- 31
- } else if (dayInYear <= 243) {
- 31
- } else if (dayInYear <= 273) {
- 30
- } else if (dayInYear <= 304) {
- 31
- } else if (dayInYear <= 334) {
- 30
- } else {
- 31
- }
- }
-
/**
* Returns number of months between time1 and time2. time1 and time2 are expressed in
* microseconds since 1.1.1970.
@@ -689,19 +700,18 @@ object DateTimeUtils {
* Otherwise, the difference is calculated based on 31 days per month, and rounding to
* 8 digits.
*/
- def monthsBetween(time1: Long, time2: Long): Double = {
+ def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = {
val millis1 = time1 / 1000L
val millis2 = time2 / 1000L
val date1 = millisToDays(millis1)
val date2 = millisToDays(millis2)
- // TODO(davies): get year, month, dayOfMonth from single function
- val dayInMonth1 = getDayOfMonth(date1)
- val dayInMonth2 = getDayOfMonth(date2)
- val months1 = getYear(date1) * 12 + getMonth(date1)
- val months2 = getYear(date2) * 12 + getMonth(date2)
-
- if (dayInMonth1 == dayInMonth2 || (dayInMonth1 == getLastDayInMonthOfMonth(date1)
- && dayInMonth2 == getLastDayInMonthOfMonth(date2))) {
+ val (year1, monthInYear1, dayInMonth1, daysToMonthEnd1) = splitDate(date1)
+ val (year2, monthInYear2, dayInMonth2, daysToMonthEnd2) = splitDate(date2)
+
+ val months1 = year1 * 12 + monthInYear1
+ val months2 = year2 * 12 + monthInYear2
+
+ if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) {
return (months1 - months2).toDouble
}
// milliseconds is enough for 8 digits precision on the right side
@@ -735,7 +745,7 @@ object DateTimeUtils {
* Returns the first date which is later than startDate and is of the given dayOfWeek.
* dayOfWeek is an integer ranges in [0, 6], and 0 is Thu, 1 is Fri, etc,.
*/
- def getNextDateForDayOfWeek(startDate: Int, dayOfWeek: Int): Int = {
+ def getNextDateForDayOfWeek(startDate: SQLDate, dayOfWeek: Int): SQLDate = {
startDate + 1 + ((dayOfWeek - 1 - startDate) % 7 + 7) % 7
}
@@ -743,40 +753,63 @@ object DateTimeUtils {
* Returns last day of the month for the given date. The date is expressed in days
* since 1.1.1970.
*/
- def getLastDayOfMonth(date: Int): Int = {
- var (year, dayInYear) = getYearAndDayInYear(date)
- if (isLeapYear(year)) {
- if (dayInYear > 31 && dayInYear <= 60) {
- return date + (60 - dayInYear)
- } else if (dayInYear > 60) {
- dayInYear = dayInYear - 1
- }
+ def getLastDayOfMonth(date: SQLDate): SQLDate = {
+ val (_, _, _, daysToMonthEnd) = splitDate(date)
+ date + daysToMonthEnd
+ }
+
+ private val TRUNC_TO_YEAR = 1
+ private val TRUNC_TO_MONTH = 2
+ private val TRUNC_INVALID = -1
+
+ /**
+ * Returns the trunc date from original date and trunc level.
+ * Trunc level should be generated using `parseTruncLevel()` and should only be 1 or 2.
+ */
+ def truncDate(d: SQLDate, level: Int): SQLDate = {
+ if (level == TRUNC_TO_YEAR) {
+ d - DateTimeUtils.getDayInYear(d) + 1
+ } else if (level == TRUNC_TO_MONTH) {
+ d - DateTimeUtils.getDayOfMonth(d) + 1
+ } else {
+ // caller make sure that this should never be reached
+ sys.error(s"Invalid trunc level: $level")
}
- val lastDayOfMonthInYear = if (dayInYear <= 31) {
- 31
- } else if (dayInYear <= 59) {
- 59
- } else if (dayInYear <= 90) {
- 90
- } else if (dayInYear <= 120) {
- 120
- } else if (dayInYear <= 151) {
- 151
- } else if (dayInYear <= 181) {
- 181
- } else if (dayInYear <= 212) {
- 212
- } else if (dayInYear <= 243) {
- 243
- } else if (dayInYear <= 273) {
- 273
- } else if (dayInYear <= 304) {
- 304
- } else if (dayInYear <= 334) {
- 334
+ }
+
+ /**
+ * Returns the truncate level, which could be TRUNC_TO_YEAR, TRUNC_TO_MONTH, or TRUNC_INVALID;
+ * TRUNC_INVALID means unsupported truncate level.
+ */
+ def parseTruncLevel(format: UTF8String): Int = {
+ if (format == null) {
+ TRUNC_INVALID
} else {
- 365
+ format.toString.toUpperCase match {
+ case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
+ case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
+ case _ => TRUNC_INVALID
+ }
}
- date + (lastDayOfMonthInYear - dayInYear)
+ }
+
+ /**
+ * Returns a timestamp in the given timezone converted from a UTC timestamp, keeping
+ * the same string representation in that timezone.
+ */
+ def fromUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = {
+ val tz = TimeZone.getTimeZone(timeZone)
+ val offset = tz.getOffset(time / 1000L)
+ time + offset * 1000L
+ }
+
+ /**
+ * Returns a UTC timestamp converted from a timestamp in the given timezone, keeping
+ * the same string representation in that timezone.
+ */
+ def toUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = {
+ val tz = TimeZone.getTimeZone(timeZone)
+ val offset = tz.getOffset(time / 1000L)
+ time - offset * 1000L
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
new file mode 100644
index 0000000000000..9ddfb3a0d3759
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import java.util.regex.Pattern
+
+object StringUtils {
+
+ // replace the _ with ., which matches exactly one occurrence of any character
+ // replace the % with .*, which matches zero or more occurrences of any character
+ def escapeLikeRegex(v: String): String = {
+ if (!v.isEmpty) {
+ "(?s)" + (' ' +: v.init).zip(v).flatMap {
+ case (prev, '\\') => ""
+ case ('\\', c) =>
+ c match {
+ case '_' => "_"
+ case '%' => "%"
+ case _ => Pattern.quote("\\" + c)
+ }
+ case (prev, c) =>
+ c match {
+ case '_' => "."
+ case '%' => ".*"
+ case _ => Pattern.quote(Character.toString(c))
+ }
+ }.mkString
+ } else {
+ v
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 0103ddcf9cfb7..bcf4d78fb9371 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -18,32 +18,34 @@
package org.apache.spark.sql.catalyst.util
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._
/**
* Helper functions to check for valid data types.
*/
object TypeUtils {
- def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = {
- if (t.isInstanceOf[NumericType] || t == NullType) {
+ def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
+ if (dt.isInstanceOf[NumericType] || dt == NullType) {
TypeCheckResult.TypeCheckSuccess
} else {
- TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric types, not $t")
+ TypeCheckResult.TypeCheckFailure(s"$caller requires numeric types, not $dt")
}
}
- def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = {
- if (t.isInstanceOf[AtomicType] || t == NullType) {
+ def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
+ if (RowOrdering.isOrderable(dt)) {
TypeCheckResult.TypeCheckSuccess
} else {
- TypeCheckResult.TypeCheckFailure(s"$caller accepts non-complex types, not $t")
+ TypeCheckResult.TypeCheckFailure(s"$caller does not support ordering on type $dt")
}
}
def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
if (types.distinct.size > 1) {
TypeCheckResult.TypeCheckFailure(
- s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}")
+ s"input to $caller should all be the same type, but it's " +
+ types.map(_.simpleString).mkString("[", ", ", "]"))
} else {
TypeCheckResult.TypeCheckSuccess
}
@@ -52,8 +54,12 @@ object TypeUtils {
def getNumeric(t: DataType): Numeric[Any] =
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
- def getOrdering(t: DataType): Ordering[Any] =
- t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]]
+ def getInterpretedOrdering(t: DataType): Ordering[Any] = {
+ t match {
+ case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
+ case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
+ }
+ }
def compareBinary(x: Array[Byte], y: Array[Byte]): Int = {
for (i <- 0 until x.length; if i < y.length) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
new file mode 100644
index 0000000000000..f6fa021adee95
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData {
+ require(keyArray.numElements() == valueArray.numElements())
+
+ override def numElements(): Int = keyArray.numElements()
+
+ override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy())
+
+ // We need to be able to check maps for equality in tests.
+ override def equals(o: Any): Boolean = {
+ if (!o.isInstanceOf[ArrayBasedMapData]) {
+ return false
+ }
+
+ val other = o.asInstanceOf[ArrayBasedMapData]
+ if (other eq null) {
+ return false
+ }
+
+ ArrayBasedMapData.toScalaMap(this) == ArrayBasedMapData.toScalaMap(other)
+ }
+
+ override def hashCode: Int = {
+ ArrayBasedMapData.toScalaMap(this).hashCode()
+ }
+
+ override def toString(): String = {
+ s"keys: $keyArray, values: $valueArray"
+ }
+}
+
+object ArrayBasedMapData {
+ def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = {
+ new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
+ }
+
+ def toScalaMap(map: ArrayBasedMapData): Map[Any, Any] = {
+ val keys = map.keyArray.asInstanceOf[GenericArrayData].array
+ val values = map.valueArray.asInstanceOf[GenericArrayData].array
+ keys.zip(values).toMap
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
index 14a7285877622..642c56f12ded1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -17,105 +17,118 @@
package org.apache.spark.sql.types
+import scala.reflect.ClassTag
+
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
abstract class ArrayData extends SpecializedGetters with Serializable {
- // todo: remove this after we handle all types.(map type need special getter)
- def get(ordinal: Int): Any
-
def numElements(): Int
- // todo: need a more efficient way to iterate array type.
- def toArray(): Array[Any] = {
- val n = numElements()
- val values = new Array[Any](n)
+ def copy(): ArrayData
+
+ def toBooleanArray(): Array[Boolean] = {
+ val size = numElements()
+ val values = new Array[Boolean](size)
var i = 0
- while (i < n) {
- if (isNullAt(i)) {
- values(i) = null
- } else {
- values(i) = get(i)
- }
+ while (i < size) {
+ values(i) = getBoolean(i)
i += 1
}
values
}
- override def toString(): String = toArray.mkString("[", ",", "]")
+ def toByteArray(): Array[Byte] = {
+ val size = numElements()
+ val values = new Array[Byte](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getByte(i)
+ i += 1
+ }
+ values
+ }
+
+ def toShortArray(): Array[Short] = {
+ val size = numElements()
+ val values = new Array[Short](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getShort(i)
+ i += 1
+ }
+ values
+ }
+
+ def toIntArray(): Array[Int] = {
+ val size = numElements()
+ val values = new Array[Int](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getInt(i)
+ i += 1
+ }
+ values
+ }
- override def equals(o: Any): Boolean = {
- if (!o.isInstanceOf[ArrayData]) {
- return false
+ def toLongArray(): Array[Long] = {
+ val size = numElements()
+ val values = new Array[Long](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getLong(i)
+ i += 1
}
+ values
+ }
- val other = o.asInstanceOf[ArrayData]
- if (other eq null) {
- return false
+ def toFloatArray(): Array[Float] = {
+ val size = numElements()
+ val values = new Array[Float](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getFloat(i)
+ i += 1
}
+ values
+ }
- val len = numElements()
- if (len != other.numElements()) {
- return false
+ def toDoubleArray(): Array[Double] = {
+ val size = numElements()
+ val values = new Array[Double](size)
+ var i = 0
+ while (i < size) {
+ values(i) = getDouble(i)
+ i += 1
}
+ values
+ }
+ def toArray[T: ClassTag](elementType: DataType): Array[T] = {
+ val size = numElements()
+ val values = new Array[T](size)
var i = 0
- while (i < len) {
- if (isNullAt(i) != other.isNullAt(i)) {
- return false
- }
- if (!isNullAt(i)) {
- val o1 = get(i)
- val o2 = other.get(i)
- o1 match {
- case b1: Array[Byte] =>
- if (!o2.isInstanceOf[Array[Byte]] ||
- !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
- return false
- }
- case f1: Float if java.lang.Float.isNaN(f1) =>
- if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
- return false
- }
- case d1: Double if java.lang.Double.isNaN(d1) =>
- if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
- return false
- }
- case _ => if (o1 != o2) {
- return false
- }
- }
+ while (i < size) {
+ if (isNullAt(i)) {
+ values(i) = null.asInstanceOf[T]
+ } else {
+ values(i) = get(i, elementType).asInstanceOf[T]
}
i += 1
}
- true
+ values
}
- override def hashCode: Int = {
- var result: Int = 37
+ // todo: specialize this.
+ def foreach(elementType: DataType, f: (Int, Any) => Unit): Unit = {
+ val size = numElements()
var i = 0
- val len = numElements()
- while (i < len) {
- val update: Int =
- if (isNullAt(i)) {
- 0
- } else {
- get(i) match {
- case b: Boolean => if (b) 0 else 1
- case b: Byte => b.toInt
- case s: Short => s.toInt
- case i: Int => i
- case l: Long => (l ^ (l >>> 32)).toInt
- case f: Float => java.lang.Float.floatToIntBits(f)
- case d: Double =>
- val b = java.lang.Double.doubleToLongBits(d)
- (b ^ (b >>> 32)).toInt
- case a: Array[Byte] => java.util.Arrays.hashCode(a)
- case other => other.hashCode()
- }
- }
- result = 37 * result + update
+ while (i < size) {
+ if (isNullAt(i)) {
+ f(i, null)
+ } else {
+ f(i, get(i, elementType))
+ }
i += 1
}
- result
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 5094058164b2f..5770f59b53077 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -75,6 +75,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
override def simpleString: String = s"array<${elementType.simpleString}>"
- private[spark] override def asNullable: ArrayType =
+ override private[spark] def asNullable: ArrayType =
ArrayType(elementType.asNullable, containsNull = true)
+
+ override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+ f(this) || elementType.existsRecursively(f)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index f4428c2e8b202..7bcd623b3f33e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -77,6 +77,11 @@ abstract class DataType extends AbstractDataType {
*/
private[spark] def asNullable: DataType
+ /**
+ * Returns true if any `DataType` in this DataType tree satisfies the given function `f`.
+ */
+ private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this)
+
override private[sql] def defaultConcreteType: DataType = this
override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index bc689810bc292..d95805c24521c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.types
+import java.math.{RoundingMode, MathContext}
+
import org.apache.spark.annotation.DeveloperApi
/**
@@ -28,7 +30,7 @@ import org.apache.spark.annotation.DeveloperApi
* - Otherwise, the decimal value is longVal / (10 ** _scale)
*/
final class Decimal extends Ordered[Decimal] with Serializable {
- import org.apache.spark.sql.types.Decimal.{BIG_DEC_ZERO, MAX_LONG_DIGITS, POW_10, ROUNDING_MODE}
+ import org.apache.spark.sql.types.Decimal._
private var decimalVal: BigDecimal = null
private var longVal: Long = 0L
@@ -188,6 +190,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
* @return true if successful, false if overflow would occur
*/
def changePrecision(precision: Int, scale: Int): Boolean = {
+ // fast path for UnsafeProjection
+ if (precision == this.precision && scale == this.scale) {
+ return true
+ }
// First, update our longVal if we can, or transfer over to using a BigDecimal
if (decimalVal.eq(null)) {
if (scale < _scale) {
@@ -224,7 +230,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
decimalVal = newVal
} else {
// We're still using Longs, but we should check whether we match the new precision
- val p = POW_10(math.min(_precision, MAX_LONG_DIGITS))
+ val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
if (longVal <= -p || longVal >= p) {
// Note that we shouldn't have been able to fix this by switching to BigDecimal
return false
@@ -257,29 +263,44 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
- def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal)
+ def + (that: Decimal): Decimal = {
+ if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
+ Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale)
+ } else {
+ Decimal(toBigDecimal + that.toBigDecimal, precision, scale)
+ }
+ }
- def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal)
+ def - (that: Decimal): Decimal = {
+ if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
+ Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale)
+ } else {
+ Decimal(toBigDecimal - that.toBigDecimal, precision, scale)
+ }
+ }
- def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
+ // HiveTypeCoercion will take care of the precision and scale of the result
+ def * (that: Decimal): Decimal =
+ Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT))
def / (that: Decimal): Decimal =
- if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
+ if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, MATH_CONTEXT))
def % (that: Decimal): Decimal =
- if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
+ if (that.isZero) null
+ else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT))
def remainder(that: Decimal): Decimal = this % that
def unary_- : Decimal = {
if (decimalVal.ne(null)) {
- Decimal(-decimalVal)
+ Decimal(-decimalVal, precision, scale)
} else {
Decimal(-longVal, precision, scale)
}
}
- def abs: Decimal = if (this.compare(Decimal(0)) < 0) this.unary_- else this
+ def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
}
object Decimal {
@@ -292,6 +313,11 @@ object Decimal {
private val BIG_DEC_ZERO = BigDecimal(0)
+ private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
+
+ private[sql] val ZERO = Decimal(0)
+ private[sql] val ONE = Decimal(1)
+
def apply(value: Double): Decimal = new Decimal().set(value)
def apply(value: Long): Decimal = new Decimal().set(value)
@@ -305,6 +331,9 @@ object Decimal {
def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
new Decimal().set(value, precision, scale)
+ def apply(value: java.math.BigDecimal, precision: Int, scale: Int): Decimal =
+ new Decimal().set(value, precision, scale)
+
def apply(unscaled: Long, precision: Int, scale: Int): Decimal =
new Decimal().set(unscaled, precision, scale)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index 7992ba947c069..459fcb6fc0acc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -18,42 +18,107 @@
package org.apache.spark.sql.types
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-class GenericArrayData(array: Array[Any]) extends ArrayData {
- private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T]
+class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData {
- override def toArray(): Array[Any] = array
+ override def copy(): ArrayData = new GenericArrayData(array.clone())
- override def get(ordinal: Int): Any = array(ordinal)
-
- override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
+ override def numElements(): Int = array.length
+ private def getAs[T](ordinal: Int) = array(ordinal).asInstanceOf[T]
+ override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null
+ override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal)
override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
-
override def getByte(ordinal: Int): Byte = getAs(ordinal)
-
override def getShort(ordinal: Int): Short = getAs(ordinal)
-
override def getInt(ordinal: Int): Int = getAs(ordinal)
-
override def getLong(ordinal: Int): Long = getAs(ordinal)
-
override def getFloat(ordinal: Int): Float = getAs(ordinal)
-
override def getDouble(ordinal: Int): Double = getAs(ordinal)
-
- override def getDecimal(ordinal: Int): Decimal = getAs(ordinal)
-
+ override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
-
override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
-
override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
-
override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
-
override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
-
- override def numElements(): Int = array.length
+ override def getMap(ordinal: Int): MapData = getAs(ordinal)
+
+ override def toString(): String = array.mkString("[", ",", "]")
+
+ override def equals(o: Any): Boolean = {
+ if (!o.isInstanceOf[GenericArrayData]) {
+ return false
+ }
+
+ val other = o.asInstanceOf[GenericArrayData]
+ if (other eq null) {
+ return false
+ }
+
+ val len = numElements()
+ if (len != other.numElements()) {
+ return false
+ }
+
+ var i = 0
+ while (i < len) {
+ if (isNullAt(i) != other.isNullAt(i)) {
+ return false
+ }
+ if (!isNullAt(i)) {
+ val o1 = array(i)
+ val o2 = other.array(i)
+ o1 match {
+ case b1: Array[Byte] =>
+ if (!o2.isInstanceOf[Array[Byte]] ||
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ return false
+ }
+ case f1: Float if java.lang.Float.isNaN(f1) =>
+ if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+ return false
+ }
+ case d1: Double if java.lang.Double.isNaN(d1) =>
+ if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+ return false
+ }
+ case _ => if (o1 != o2) {
+ return false
+ }
+ }
+ }
+ i += 1
+ }
+ true
+ }
+
+ override def hashCode: Int = {
+ var result: Int = 37
+ var i = 0
+ val len = numElements()
+ while (i < len) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0
+ } else {
+ array(i) match {
+ case b: Boolean => if (b) 0 else 1
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i
+ case l: Long => (l ^ (l >>> 32)).toInt
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case a: Array[Byte] => java.util.Arrays.hashCode(a)
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala
new file mode 100644
index 0000000000000..f50969f0f0b79
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+abstract class MapData extends Serializable {
+
+ def numElements(): Int
+
+ def keyArray(): ArrayData
+
+ def valueArray(): ArrayData
+
+ def copy(): MapData
+
+ def foreach(keyType: DataType, valueType: DataType, f: (Any, Any) => Unit): Unit = {
+ val length = numElements()
+ val keys = keyArray()
+ val values = valueArray()
+ var i = 0
+ while (i < length) {
+ f(keys.get(i, keyType), values.get(i, valueType))
+ i += 1
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index ac34b642827ca..00461e529ca0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -62,8 +62,12 @@ case class MapType(
override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
- private[spark] override def asNullable: MapType =
+ override private[spark] def asNullable: MapType =
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
+
+ override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+ f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 2ef97a427c37e..d8968ef806390 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -24,7 +24,7 @@ import org.json4s.JsonDSL._
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
/**
@@ -292,7 +292,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
private[sql] def merge(that: StructType): StructType =
StructType.merge(this, that).asInstanceOf[StructType]
- private[spark] override def asNullable: StructType = {
+ override private[spark] def asNullable: StructType = {
val newFields = fields.map {
case StructField(name, dataType, nullable, metadata) =>
StructField(name, dataType.asNullable, nullable = true, metadata)
@@ -300,8 +300,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
StructType(newFields)
}
-}
+ override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+ f(this) || fields.exists(field => field.dataType.existsRecursively(f))
+ }
+
+ private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType))
+}
object StructType extends AbstractDataType {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 75ae29d690770..11e0c120f4072 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -24,6 +24,7 @@ import java.math.MathContext
import scala.util.Random
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
/**
* Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random
@@ -65,6 +66,19 @@ object RandomDataGenerator {
Some(f)
}
+ /**
+ * Returns a randomly generated schema, based on the given accepted types.
+ *
+ * @param numFields the number of fields in this schema
+ * @param acceptedTypes types to draw from.
+ */
+ def randomSchema(numFields: Int, acceptedTypes: Seq[DataType]): StructType = {
+ StructType(Seq.tabulate(numFields) { i =>
+ val dt = acceptedTypes(Random.nextInt(acceptedTypes.size))
+ StructField("col_" + i, dt, nullable = true)
+ })
+ }
+
/**
* Returns a function which generates random values for the given [[DataType]], or `None` if no
* random data generator is defined for that data type. The generated values will use an external
@@ -93,8 +107,16 @@ object RandomDataGenerator {
case BooleanType => Some(() => rand.nextBoolean())
case DateType => Some(() => new java.sql.Date(rand.nextInt()))
case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
+ case CalendarIntervalType => Some(() => {
+ val months = rand.nextInt(1000)
+ val ns = rand.nextLong()
+ new CalendarInterval(months, ns)
+ })
case DecimalType.Fixed(precision, scale) => Some(
- () => BigDecimal.apply(rand.nextLong, rand.nextInt, new MathContext(precision)))
+ () => BigDecimal.apply(
+ rand.nextLong() % math.pow(10, precision).toLong,
+ scale,
+ new MathContext(precision)))
case DoubleType => randomNumeric[Double](
rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index c046dbf4dc2c9..827f7ce692712 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite {
}
}
- test("HashPartitioning is the output partitioning") {
+ test("HashPartitioning (with nullSafe = true) is the output partitioning") {
// Cases which do not need an exchange between two data properties.
checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
new file mode 100644
index 0000000000000..5b802ccc637dd
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/PartitioningSuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
+
+class PartitioningSuite extends SparkFunSuite {
+ test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") {
+ val expressions = Seq(Literal(2), Literal(3))
+ // Consider two HashPartitionings that have the same _set_ of hash expressions but which are
+ // created with different orderings of those expressions:
+ val partitioningA = HashPartitioning(expressions, 100)
+ val partitioningB = HashPartitioning(expressions.reverse, 100)
+ // These partitionings are not considered equal:
+ assert(partitioningA != partitioningB)
+ // However, they both satisfy the same clustered distribution:
+ val distribution = ClusteredDistribution(expressions)
+ assert(partitioningA.satisfies(distribution))
+ assert(partitioningB.satisfies(distribution))
+ // These partitionings compute different hashcodes for the same input row:
+ def computeHashCode(partitioning: HashPartitioning): Int = {
+ val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty)
+ hashExprProj.apply(InternalRow.empty).hashCode()
+ }
+ assert(computeHashCode(partitioningA) != computeHashCode(partitioningB))
+ // Thus, these partitionings are incompatible:
+ assert(!partitioningA.compatibleWith(partitioningB))
+ assert(!partitioningB.compatibleWith(partitioningA))
+ assert(!partitioningA.guarantees(partitioningB))
+ assert(!partitioningB.guarantees(partitioningA))
+
+ // Just to be sure that we haven't cheated by having these methods always return false,
+ // check that identical partitionings are still compatible with and guarantee each other:
+ assert(partitioningA === partitioningA)
+ assert(partitioningA.guarantees(partitioningA))
+ assert(partitioningA.compatibleWith(partitioningA))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 2588df98246dd..63b475b6366c2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -42,8 +42,8 @@ case class UnresolvedTestPlan() extends LeafNode {
override def output: Seq[Attribute] = Nil
}
-class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
- import AnalysisSuite._
+class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter {
+ import TestRelations._
def errorTest(
name: String,
@@ -51,15 +51,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
errorMessages: Seq[String],
caseSensitive: Boolean = true): Unit = {
test(name) {
- val error = intercept[AnalysisException] {
- if (caseSensitive) {
- caseSensitiveAnalyze(plan)
- } else {
- caseInsensitiveAnalyze(plan)
- }
- }
-
- errorMessages.foreach(m => assert(error.getMessage.toLowerCase.contains(m.toLowerCase)))
+ assertAnalysisError(plan, errorMessages, caseSensitive)
}
}
@@ -68,22 +60,22 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
errorTest(
"single invalid type, single arg",
testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)),
- "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" ::
- "'null' is of type date" ::Nil)
+ "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" ::
+ "'null' is of date type" :: Nil)
errorTest(
"single invalid type, second arg",
testRelation.select(
TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)),
- "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" ::
- "'null' is of type date" ::Nil)
+ "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" ::
+ "'null' is of date type" :: Nil)
errorTest(
"multiple invalid type",
testRelation.select(
TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)),
"cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" ::
- "expected to be of type int" :: "'null' is of type date" ::Nil)
+ "requires int type" :: "'null' is of date type" :: Nil)
errorTest(
"unresolved window function",
@@ -111,12 +103,12 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
errorTest(
"bad casts",
testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
- "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
+ "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
errorTest(
"sorting by unsupported column types",
listRelation.orderBy('list.asc),
- "sorting" :: "type" :: "array" :: Nil)
+ "sort" :: "type" :: "array" :: Nil)
errorTest(
"non-boolean filters",
@@ -169,11 +161,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
assert(plan.resolved)
- val message = intercept[AnalysisException] {
- caseSensitiveAnalyze(plan)
- }.getMessage
-
- assert(message.contains("resolved attribute(s) a#1 missing from a#2"))
+ assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil)
}
test("error test for self-join") {
@@ -181,7 +169,61 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
val error = intercept[AnalysisException] {
SimpleAnalyzer.checkAnalysis(join)
}
- error.message.contains("Failure when resolving conflicting references in Join")
- error.message.contains("Conflicting attributes")
+ assert(error.message.contains("Failure when resolving conflicting references in Join"))
+ assert(error.message.contains("Conflicting attributes"))
+ }
+
+ test("aggregation can't work on binary and map types") {
+ val plan =
+ Aggregate(
+ AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil,
+ Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+ LocalRelation(
+ AttributeReference("a", BinaryType)(exprId = ExprId(2)),
+ AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+
+ assertAnalysisError(plan,
+ "binary type expression a cannot be used in grouping expression" :: Nil)
+
+ val plan2 =
+ Aggregate(
+ AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil,
+ Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+ LocalRelation(
+ AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
+ AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+
+ assertAnalysisError(plan2,
+ "map type expression a cannot be used in grouping expression" :: Nil)
+ }
+
+ test("Join can't work on binary and map types") {
+ val plan =
+ Join(
+ LocalRelation(
+ AttributeReference("a", BinaryType)(exprId = ExprId(2)),
+ AttributeReference("b", IntegerType)(exprId = ExprId(1))),
+ LocalRelation(
+ AttributeReference("c", BinaryType)(exprId = ExprId(4)),
+ AttributeReference("d", IntegerType)(exprId = ExprId(3))),
+ Inner,
+ Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
+ AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
+
+ assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil)
+
+ val plan2 =
+ Join(
+ LocalRelation(
+ AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
+ AttributeReference("b", IntegerType)(exprId = ExprId(1))),
+ LocalRelation(
+ AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)),
+ AttributeReference("d", IntegerType)(exprId = ExprId(3))),
+ Inner,
+ Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
+ AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
+
+ assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index a86cefe941e8e..c944bc69e25b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -24,61 +24,8 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-// todo: remove this and use AnalysisTest instead.
-object AnalysisSuite {
- val caseSensitiveConf = new SimpleCatalystConf(true)
- val caseInsensitiveConf = new SimpleCatalystConf(false)
-
- val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf)
- val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf)
-
- val caseSensitiveAnalyzer =
- new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) {
- override val extendedResolutionRules = EliminateSubQueries :: Nil
- }
- val caseInsensitiveAnalyzer =
- new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) {
- override val extendedResolutionRules = EliminateSubQueries :: Nil
- }
-
- def caseSensitiveAnalyze(plan: LogicalPlan): Unit =
- caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer.execute(plan))
-
- def caseInsensitiveAnalyze(plan: LogicalPlan): Unit =
- caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer.execute(plan))
-
- val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
- val testRelation2 = LocalRelation(
- AttributeReference("a", StringType)(),
- AttributeReference("b", StringType)(),
- AttributeReference("c", DoubleType)(),
- AttributeReference("d", DecimalType(10, 2))(),
- AttributeReference("e", ShortType)())
-
- val nestedRelation = LocalRelation(
- AttributeReference("top", StructType(
- StructField("duplicateField", StringType) ::
- StructField("duplicateField", StringType) ::
- StructField("differentCase", StringType) ::
- StructField("differentcase", StringType) :: Nil
- ))())
-
- val nestedRelation2 = LocalRelation(
- AttributeReference("top", StructType(
- StructField("aField", StringType) ::
- StructField("bField", StringType) ::
- StructField("cField", StringType) :: Nil
- ))())
-
- val listRelation = LocalRelation(
- AttributeReference("list", ArrayType(IntegerType))())
-
- caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
- caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
-}
-
-
class AnalysisSuite extends AnalysisTest {
+ import TestRelations._
test("union project *") {
val plan = (1 to 100)
@@ -165,39 +112,11 @@ class AnalysisSuite extends AnalysisTest {
test("pull out nondeterministic expressions from Sort") {
val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation)
- val analyzed = caseSensitiveAnalyzer.execute(plan)
- analyzed.transform {
- case s: Sort if s.expressions.exists(!_.deterministic) =>
- fail("nondeterministic expressions are not allowed in Sort")
- }
- }
-
- test("remove still-need-evaluate ordering expressions from sort") {
- val a = testRelation2.output(0)
- val b = testRelation2.output(1)
-
- def makeOrder(e: Expression): SortOrder = SortOrder(e, Ascending)
-
- val noEvalOrdering = makeOrder(a)
- val noEvalOrderingWithAlias = makeOrder(Alias(Alias(b, "name1")(), "name2")())
-
- val needEvalExpr = Coalesce(Seq(a, Literal("1")))
- val needEvalExpr2 = Coalesce(Seq(a, b))
- val needEvalOrdering = makeOrder(needEvalExpr)
- val needEvalOrdering2 = makeOrder(needEvalExpr2)
-
- val plan = Sort(
- Seq(noEvalOrdering, noEvalOrderingWithAlias, needEvalOrdering, needEvalOrdering2),
- false, testRelation2)
-
- val evaluatedOrdering = makeOrder(AttributeReference("_sortCondition", StringType)())
- val materializedExprs = Seq(needEvalExpr, needEvalExpr2).map(e => Alias(e, "_sortCondition")())
-
+ val projected = Alias(Rand(33), "_nondeterministic")()
val expected =
- Project(testRelation2.output,
- Sort(Seq(makeOrder(a), makeOrder(b), evaluatedOrdering, evaluatedOrdering), false,
- Project(testRelation2.output ++ materializedExprs, testRelation2)))
-
+ Project(testRelation.output,
+ Sort(Seq(SortOrder(projected.toAttribute, Ascending)), false,
+ Project(testRelation.output :+ projected, testRelation)))
checkAnalysis(plan, expected)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index fdb4f28950daf..53b3695a86be5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -17,40 +17,11 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
-import org.apache.spark.sql.types._
trait AnalysisTest extends PlanTest {
- val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
-
- val testRelation2 = LocalRelation(
- AttributeReference("a", StringType)(),
- AttributeReference("b", StringType)(),
- AttributeReference("c", DoubleType)(),
- AttributeReference("d", DecimalType(10, 2))(),
- AttributeReference("e", ShortType)())
-
- val nestedRelation = LocalRelation(
- AttributeReference("top", StructType(
- StructField("duplicateField", StringType) ::
- StructField("duplicateField", StringType) ::
- StructField("differentCase", StringType) ::
- StructField("differentcase", StringType) :: Nil
- ))())
-
- val nestedRelation2 = LocalRelation(
- AttributeReference("top", StructType(
- StructField("aField", StringType) ::
- StructField("bField", StringType) ::
- StructField("cField", StringType) :: Nil
- ))())
-
- val listRelation = LocalRelation(
- AttributeReference("list", ArrayType(IntegerType))())
val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = {
val caseSensitiveConf = new SimpleCatalystConf(true)
@@ -59,8 +30,8 @@ trait AnalysisTest extends PlanTest {
val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf)
val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf)
- caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
- caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+ caseSensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation)
+ caseInsensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation)
new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) {
override val extendedResolutionRules = EliminateSubQueries :: Nil
@@ -100,6 +71,8 @@ trait AnalysisTest extends PlanTest {
val e = intercept[Exception] {
analyzer.checkAnalysis(analyzer.execute(inputPlan))
}
- expectedErrors.forall(e.getMessage.contains)
+ assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains),
+ s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " +
+ s"actually we get ${e.getMessage}")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index a52e4cb4dfd9f..c9bcc68f02030 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -53,9 +53,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}
test("check types for unary arithmetic") {
- assertError(UnaryMinus('stringField), "type (numeric or calendarinterval)")
- assertError(Abs('stringField), "expected to be of type numeric")
- assertError(BitwiseNot('stringField), "expected to be of type integral")
+ assertError(UnaryMinus('stringField), "(numeric or calendarinterval) type")
+ assertError(Abs('stringField), "requires numeric type")
+ assertError(BitwiseNot('stringField), "requires integral type")
}
test("check types for binary arithmetic") {
@@ -78,21 +78,21 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
- assertError(Add('booleanField, 'booleanField), "accepts (numeric or calendarinterval) type")
+ assertError(Add('booleanField, 'booleanField), "requires (numeric or calendarinterval) type")
assertError(Subtract('booleanField, 'booleanField),
- "accepts (numeric or calendarinterval) type")
- assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
- assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
- assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")
+ "requires (numeric or calendarinterval) type")
+ assertError(Multiply('booleanField, 'booleanField), "requires numeric type")
+ assertError(Divide('booleanField, 'booleanField), "requires numeric type")
+ assertError(Remainder('booleanField, 'booleanField), "requires numeric type")
- assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type")
- assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type")
- assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type")
+ assertError(BitwiseAnd('booleanField, 'booleanField), "requires integral type")
+ assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type")
+ assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type")
assertError(MaxOf('complexField, 'complexField),
- s"accepts ${TypeCollection.Ordered.simpleString} type")
+ s"requires ${TypeCollection.Ordered.simpleString} type")
assertError(MinOf('complexField, 'complexField),
- s"accepts ${TypeCollection.Ordered.simpleString} type")
+ s"requires ${TypeCollection.Ordered.simpleString} type")
}
test("check types for predicates") {
@@ -116,13 +116,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
assertError(LessThan('complexField, 'complexField),
- s"accepts ${TypeCollection.Ordered.simpleString} type")
+ s"requires ${TypeCollection.Ordered.simpleString} type")
assertError(LessThanOrEqual('complexField, 'complexField),
- s"accepts ${TypeCollection.Ordered.simpleString} type")
+ s"requires ${TypeCollection.Ordered.simpleString} type")
assertError(GreaterThan('complexField, 'complexField),
- s"accepts ${TypeCollection.Ordered.simpleString} type")
+ s"requires ${TypeCollection.Ordered.simpleString} type")
assertError(GreaterThanOrEqual('complexField, 'complexField),
- s"accepts ${TypeCollection.Ordered.simpleString} type")
+ s"requires ${TypeCollection.Ordered.simpleString} type")
assertError(If('intField, 'stringField, 'stringField),
"type of predicate expression in If should be boolean")
@@ -145,11 +145,11 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(SumDistinct('stringField))
assertSuccess(Average('stringField))
- assertError(Min('complexField), "function min accepts non-complex type")
- assertError(Max('complexField), "function max accepts non-complex type")
- assertError(Sum('booleanField), "function sum accepts numeric type")
- assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type")
- assertError(Average('booleanField), "function average accepts numeric type")
+ assertError(Min('complexField), "min does not support ordering on type")
+ assertError(Max('complexField), "max does not support ordering on type")
+ assertError(Sum('booleanField), "function sum requires numeric type")
+ assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type")
+ assertError(Average('booleanField), "function average requires numeric type")
}
test("check types for others") {
@@ -181,8 +181,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(Round('intField, Literal(1)))
assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
- assertError(Round('intField, 'booleanField), "expected to be of type int")
- assertError(Round('intField, 'complexField), "expected to be of type int")
- assertError(Round('booleanField, 'intField), "expected to be of type numeric")
+ assertError(Round('intField, 'booleanField), "requires int type")
+ assertError(Round('intField, 'complexField), "requires int type")
+ assertError(Round('booleanField, 'intField), "requires numeric type")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 70608771dd110..cbdf453f600ab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -251,6 +251,18 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Nil))
}
+ test("nanvl casts") {
+ ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+ NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
+ NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
+ ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+ NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)),
+ NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType)))
+ ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+ NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
+ NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
+ }
+
test("type coercion for If") {
val rule = HiveTypeCoercion.IfCoercion
ruleTest(rule,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
new file mode 100644
index 0000000000000..05b870705e7ea
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.types._
+
+object TestRelations {
+ val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
+
+ val testRelation2 = LocalRelation(
+ AttributeReference("a", StringType)(),
+ AttributeReference("b", StringType)(),
+ AttributeReference("c", DoubleType)(),
+ AttributeReference("d", DecimalType(10, 2))(),
+ AttributeReference("e", ShortType)())
+
+ val nestedRelation = LocalRelation(
+ AttributeReference("top", StructType(
+ StructField("duplicateField", StringType) ::
+ StructField("duplicateField", StringType) ::
+ StructField("differentCase", StringType) ::
+ StructField("differentcase", StringType) :: Nil
+ ))())
+
+ val nestedRelation2 = LocalRelation(
+ AttributeReference("top", StructType(
+ StructField("aField", StringType) ::
+ StructField("bField", StringType) ::
+ StructField("cField", StringType) :: Nil
+ ))())
+
+ val listRelation = LocalRelation(
+ AttributeReference("list", ArrayType(IntegerType))())
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index d03b0fbbfb2b2..a1f15e4f0f25a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.Decimal
+import org.apache.spark.sql.types._
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+ import IntegralLiteralTestUtils._
+
/**
* Runs through the testFunc for all numeric data types.
*
@@ -47,6 +49,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Add(Literal.create(null, left.dataType), right), null)
checkEvaluation(Add(left, Literal.create(null, right.dataType)), null)
}
+ checkEvaluation(Add(positiveShortLit, negativeShortLit), -1.toShort)
+ checkEvaluation(Add(positiveIntLit, negativeIntLit), -1)
+ checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L)
}
test("- (UnaryMinus)") {
@@ -56,6 +61,16 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(UnaryMinus(input), convert(-1))
checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
}
+ checkEvaluation(UnaryMinus(Literal(Long.MinValue)), Long.MinValue)
+ checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue)
+ checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue)
+ checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
+ checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort)
+ checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort)
+ checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt)
+ checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt)
+ checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong)
+ checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong)
}
test("- (Minus)") {
@@ -66,6 +81,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null)
checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null)
}
+ checkEvaluation(Subtract(positiveShortLit, negativeShortLit),
+ (positiveShort - negativeShort).toShort)
+ checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt)
+ checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong)
}
test("* (Multiply)") {
@@ -76,6 +95,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null)
checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null)
}
+ checkEvaluation(Multiply(positiveShortLit, negativeShortLit),
+ (positiveShort * negativeShort).toShort)
+ checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt)
+ checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong)
}
test("/ (Divide) basic") {
@@ -95,6 +118,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
checkEvaluation(Divide(Literal(1), Literal(2)), 0)
checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong)
+ checkEvaluation(Divide(positiveShortLit, negativeShortLit), 0.toShort)
+ checkEvaluation(Divide(positiveIntLit, negativeIntLit), 0)
+ checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L)
}
test("/ (Divide) for floating point") {
@@ -112,6 +138,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null)
checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0
}
+ checkEvaluation(Remainder(positiveShortLit, positiveShortLit), 0.toShort)
+ checkEvaluation(Remainder(negativeShortLit, negativeShortLit), 0.toShort)
+ checkEvaluation(Remainder(positiveIntLit, positiveIntLit), 0)
+ checkEvaluation(Remainder(negativeIntLit, negativeIntLit), 0)
+ checkEvaluation(Remainder(positiveLongLit, positiveLongLit), 0L)
+ checkEvaluation(Remainder(negativeLongLit, negativeLongLit), 0L)
}
test("Abs") {
@@ -123,6 +155,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Abs(Literal(convert(-1))), convert(1))
checkEvaluation(Abs(Literal.create(null, dataType)), null)
}
+ checkEvaluation(Abs(positiveShortLit), positiveShort)
+ checkEvaluation(Abs(negativeShortLit), (- negativeShort).toShort)
+ checkEvaluation(Abs(positiveIntLit), positiveInt)
+ checkEvaluation(Abs(negativeIntLit), - negativeInt)
+ checkEvaluation(Abs(positiveLongLit), positiveLong)
+ checkEvaluation(Abs(negativeLongLit), - negativeLong)
}
test("MaxOf basic") {
@@ -134,6 +172,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MaxOf(Literal.create(null, small.dataType), large), convert(2))
checkEvaluation(MaxOf(large, Literal.create(null, small.dataType)), convert(2))
}
+ checkEvaluation(MaxOf(positiveShortLit, negativeShortLit), (positiveShort).toShort)
+ checkEvaluation(MaxOf(positiveIntLit, negativeIntLit), positiveInt)
+ checkEvaluation(MaxOf(positiveLongLit, negativeLongLit), positiveLong)
}
test("MaxOf for atomic type") {
@@ -152,6 +193,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MinOf(Literal.create(null, small.dataType), large), convert(2))
checkEvaluation(MinOf(small, Literal.create(null, small.dataType)), convert(1))
}
+ checkEvaluation(MinOf(positiveShortLit, negativeShortLit), (negativeShort).toShort)
+ checkEvaluation(MinOf(positiveIntLit, negativeIntLit), negativeInt)
+ checkEvaluation(MinOf(positiveLongLit, negativeLongLit), negativeLong)
}
test("MinOf for atomic type") {
@@ -170,9 +214,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null)
checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0
}
- checkEvaluation(Pmod(-7, 3), 2)
- checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005)
- checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1))
- checkEvaluation(Pmod(2L, Long.MaxValue), 2L)
+ checkEvaluation(Pmod(Literal(-7), Literal(3)), 2)
+ checkEvaluation(Pmod(Literal(7.2D), Literal(4.1D)), 3.1000000000000005)
+ checkEvaluation(Pmod(Literal(Decimal(0.7)), Literal(Decimal(0.2))), Decimal(0.1))
+ checkEvaluation(Pmod(Literal(2L), Literal(Long.MaxValue)), 2L)
+ checkEvaluation(Pmod(positiveShort, negativeShort), positiveShort.toShort)
+ checkEvaluation(Pmod(positiveInt, negativeInt), positiveInt)
+ checkEvaluation(Pmod(positiveLong, negativeLong), positiveLong)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
index fa30fbe528479..4fc1c06153595 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala
@@ -23,6 +23,8 @@ import org.apache.spark.sql.types._
class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+ import IntegralLiteralTestUtils._
+
test("BitwiseNOT") {
def check(input: Any, expected: Any): Unit = {
val expr = BitwiseNot(Literal(input))
@@ -37,6 +39,12 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
check(123456789123L, ~123456789123L)
checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null)
+ checkEvaluation(BitwiseNot(positiveShortLit), (~positiveShort).toShort)
+ checkEvaluation(BitwiseNot(negativeShortLit), (~negativeShort).toShort)
+ checkEvaluation(BitwiseNot(positiveIntLit), ~positiveInt)
+ checkEvaluation(BitwiseNot(negativeIntLit), ~negativeInt)
+ checkEvaluation(BitwiseNot(positiveLongLit), ~positiveLong)
+ checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong)
}
test("BitwiseAnd") {
@@ -56,6 +64,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(BitwiseAnd(nullLit, Literal(1)), null)
checkEvaluation(BitwiseAnd(Literal(1), nullLit), null)
checkEvaluation(BitwiseAnd(nullLit, nullLit), null)
+ checkEvaluation(BitwiseAnd(positiveShortLit, negativeShortLit),
+ (positiveShort & negativeShort).toShort)
+ checkEvaluation(BitwiseAnd(positiveIntLit, negativeIntLit), positiveInt & negativeInt)
+ checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong)
}
test("BitwiseOr") {
@@ -75,6 +87,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(BitwiseOr(nullLit, Literal(1)), null)
checkEvaluation(BitwiseOr(Literal(1), nullLit), null)
checkEvaluation(BitwiseOr(nullLit, nullLit), null)
+ checkEvaluation(BitwiseOr(positiveShortLit, negativeShortLit),
+ (positiveShort | negativeShort).toShort)
+ checkEvaluation(BitwiseOr(positiveIntLit, negativeIntLit), positiveInt | negativeInt)
+ checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong)
}
test("BitwiseXor") {
@@ -94,5 +110,10 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(BitwiseXor(nullLit, Literal(1)), null)
checkEvaluation(BitwiseXor(Literal(1), nullLit), null)
checkEvaluation(BitwiseXor(nullLit, nullLit), null)
+
+ checkEvaluation(BitwiseXor(positiveShortLit, negativeShortLit),
+ (positiveShort ^ negativeShort).toShort)
+ checkEvaluation(BitwiseXor(positiveIntLit, negativeIntLit), positiveInt ^ negativeInt)
+ checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 4f35b653d73c0..1ad70733eae03 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -242,10 +242,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
- checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0))
+ checkEvaluation(cast(123L, DecimalType(3, 1)), null)
- // TODO: Fix the following bug and re-enable it.
- // checkEvaluation(cast(123L, DecimalType(2, 0)), null)
+ checkEvaluation(cast(123L, DecimalType(2, 0)), null)
}
test("cast from boolean") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index f4fbc49677ca3..e323467af5f4a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
import scala.math._
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.{Row, RandomDataGenerator}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
/**
* Additional tests for code generation.
@@ -53,7 +54,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
// GenerateOrdering agrees with RowOrdering.
(DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
test(s"GenerateOrdering with $dataType") {
- val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType))
+ val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType))
val genOrdering = GenerateOrdering.generate(
BoundReference(0, dataType, nullable = true).asc ::
BoundReference(1, dataType, nullable = true).asc :: Nil)
@@ -86,11 +87,51 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val length = 5000
val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
val plan = GenerateMutableProjection.generate(expressions)()
- val actual = plan(new GenericMutableRow(length)).toSeq
+ val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
val expected = Seq.fill(length)(true)
if (!checkResult(actual, expected)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
}
}
+
+ test("test generated safe and unsafe projection") {
+ val schema = new StructType(Array(
+ StructField("a", StringType, true),
+ StructField("b", IntegerType, true),
+ StructField("c", new StructType(Array(
+ StructField("aa", StringType, true),
+ StructField("bb", IntegerType, true)
+ )), true),
+ StructField("d", new StructType(Array(
+ StructField("a", new StructType(Array(
+ StructField("b", StringType, true),
+ StructField("", IntegerType, true)
+ )), true)
+ )), true)
+ ))
+ val row = Row("a", 1, Row("b", 2), Row(Row("c", 3)))
+ val lit = Literal.create(row, schema)
+ val internalRow = lit.value.asInstanceOf[InternalRow]
+
+ val unsafeProj = UnsafeProjection.create(schema)
+ val unsafeRow: UnsafeRow = unsafeProj(internalRow)
+ assert(unsafeRow.getUTF8String(0) === UTF8String.fromString("a"))
+ assert(unsafeRow.getInt(1) === 1)
+ assert(unsafeRow.getStruct(2, 2).getUTF8String(0) === UTF8String.fromString("b"))
+ assert(unsafeRow.getStruct(2, 2).getInt(1) === 2)
+ assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getUTF8String(0) ===
+ UTF8String.fromString("c"))
+ assert(unsafeRow.getStruct(3, 1).getStruct(0, 2).getInt(1) === 3)
+
+ val fromUnsafe = FromUnsafeProjection(schema)
+ val internalRow2 = fromUnsafe(unsafeRow)
+ assert(internalRow === internalRow2)
+
+ // update unsafeRow should not affect internalRow2
+ unsafeRow.setInt(1, 10)
+ unsafeRow.getStruct(2, 2).setInt(1, 10)
+ unsafeRow.getStruct(3, 1).getStruct(0, 2).setInt(1, 4)
+ assert(internalRow === internalRow2)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index 28c41b57169f9..95f0e38212a1a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -43,4 +43,41 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null)
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
}
+
+ test("Sort Array") {
+ val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+ val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
+ val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
+
+ checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
+ checkEvaluation(new SortArray(a1), Seq[Integer]())
+ checkEvaluation(new SortArray(a2), Seq("a", "b"))
+ checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
+ checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
+ checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
+ checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
+ checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
+ checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
+ checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
+ checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
+ checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))
+
+ checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
+ }
+
+ test("Array contains") {
+ val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+ val a2 = Literal.create(Seq(null), ArrayType(LongType))
+
+ checkEvaluation(ArrayContains(a0, Literal(1)), true)
+ checkEvaluation(ArrayContains(a0, Literal(0)), false)
+ checkEvaluation(ArrayContains(a0, Literal(null)), false)
+
+ checkEvaluation(ArrayContains(a1, Literal("")), true)
+ checkEvaluation(ArrayContains(a1, Literal(null)), false)
+
+ checkEvaluation(ArrayContains(a2, Literal(null)), false)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 3fa246b69d1f1..e60990aeb423f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -171,8 +171,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
test("error message of ExtractValue") {
val structType = StructType(StructField("a", StringType, true) :: Nil)
- val arrayStructType = ArrayType(structType)
- val arrayType = ArrayType(StringType)
val otherType = StringType
def checkErrorMessage(
@@ -189,8 +187,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
}
checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
- checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal")
- checkErrorMessage(arrayType, StringType, "Array index should be integral type")
checkErrorMessage(otherType, StringType, "Can't extract value from")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index fd1d6c1d25497..f9b73f1a75e73 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -17,17 +17,19 @@
package org.apache.spark.sql.catalyst.expressions
-import java.sql.{Timestamp, Date}
+import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Calendar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+ import IntegralLiteralTestUtils._
+
val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
val sdfDate = new SimpleDateFormat("yyyy-MM-dd")
val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime)
@@ -48,15 +50,13 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("DayOfYear") {
val sdfDay = new SimpleDateFormat("D")
- (1998 to 2002).foreach { y =>
- (0 to 3).foreach { m =>
- (0 to 5).foreach { i =>
- val c = Calendar.getInstance()
- c.set(y, m, 28, 0, 0, 0)
- c.add(Calendar.DATE, i)
- checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
- sdfDay.format(c.getTime).toInt)
- }
+ (0 to 3).foreach { m =>
+ (0 to 5).foreach { i =>
+ val c = Calendar.getInstance()
+ c.set(2000, m, 28, 0, 0, 0)
+ c.add(Calendar.DATE, i)
+ checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))),
+ sdfDay.format(c.getTime).toInt)
}
}
checkEvaluation(DayOfYear(Literal.create(null, DateType)), null)
@@ -214,6 +214,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
null)
checkEvaluation(DateAdd(Literal.create(null, DateType), Literal.create(null, IntegerType)),
null)
+ checkEvaluation(
+ DateAdd(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 49627)
+ checkEvaluation(
+ DateAdd(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -15910)
}
test("date_sub") {
@@ -228,6 +232,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
null)
checkEvaluation(DateSub(Literal.create(null, DateType), Literal.create(null, IntegerType)),
null)
+ checkEvaluation(
+ DateSub(Literal(Date.valueOf("2016-02-28")), positiveIntLit), -15909)
+ checkEvaluation(
+ DateSub(Literal(Date.valueOf("2016-02-28")), negativeIntLit), 49628)
}
test("time_add") {
@@ -282,6 +290,12 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null)
checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)),
null)
+ checkEvaluation(
+ AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), -7293498)
+ checkEvaluation(
+ AddMonths(Literal(Date.valueOf("2016-02-28")), positiveIntLit), 1014213)
+ checkEvaluation(
+ AddMonths(Literal(Date.valueOf("2016-02-28")), negativeIntLit), -980528)
}
test("months_between") {
@@ -351,6 +365,34 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
}
+ test("function to_date") {
+ checkEvaluation(
+ ToDate(Literal(Date.valueOf("2015-07-22"))),
+ DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22")))
+ checkEvaluation(ToDate(Literal.create(null, DateType)), null)
+ }
+
+ test("function trunc") {
+ def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
+ checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)),
+ expected)
+ checkEvaluation(
+ TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
+ expected)
+ }
+ val date = Date.valueOf("2015-07-22")
+ Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach{ fmt =>
+ testTrunc(date, fmt, Date.valueOf("2015-01-01"))
+ }
+ Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
+ testTrunc(date, fmt, Date.valueOf("2015-07-01"))
+ }
+ testTrunc(date, "DD", null)
+ testTrunc(date, null, null)
+ testTrunc(null, "MON", null)
+ testTrunc(null, null, null)
+ }
+
test("from_unixtime") {
val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
@@ -406,4 +448,57 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null)
}
+ test("datediff") {
+ checkEvaluation(
+ DateDiff(Literal(Date.valueOf("2015-07-24")), Literal(Date.valueOf("2015-07-21"))), 3)
+ checkEvaluation(
+ DateDiff(Literal(Date.valueOf("2015-07-21")), Literal(Date.valueOf("2015-07-24"))), -3)
+ checkEvaluation(DateDiff(Literal.create(null, DateType), Literal(Date.valueOf("2015-07-24"))),
+ null)
+ checkEvaluation(DateDiff(Literal(Date.valueOf("2015-07-24")), Literal.create(null, DateType)),
+ null)
+ checkEvaluation(
+ DateDiff(Literal.create(null, DateType), Literal.create(null, DateType)),
+ null)
+ }
+
+ test("to_utc_timestamp") {
+ def test(t: String, tz: String, expected: String): Unit = {
+ checkEvaluation(
+ ToUTCTimestamp(
+ Literal.create(if (t != null) Timestamp.valueOf(t) else null, TimestampType),
+ Literal.create(tz, StringType)),
+ if (expected != null) Timestamp.valueOf(expected) else null)
+ checkEvaluation(
+ ToUTCTimestamp(
+ Literal.create(if (t != null) Timestamp.valueOf(t) else null, TimestampType),
+ NonFoldableLiteral.create(tz, StringType)),
+ if (expected != null) Timestamp.valueOf(expected) else null)
+ }
+ test("2015-07-24 00:00:00", "PST", "2015-07-24 07:00:00")
+ test("2015-01-24 00:00:00", "PST", "2015-01-24 08:00:00")
+ test(null, "UTC", null)
+ test("2015-07-24 00:00:00", null, null)
+ test(null, null, null)
+ }
+
+ test("from_utc_timestamp") {
+ def test(t: String, tz: String, expected: String): Unit = {
+ checkEvaluation(
+ FromUTCTimestamp(
+ Literal.create(if (t != null) Timestamp.valueOf(t) else null, TimestampType),
+ Literal.create(tz, StringType)),
+ if (expected != null) Timestamp.valueOf(expected) else null)
+ checkEvaluation(
+ FromUTCTimestamp(
+ Literal.create(if (t != null) Timestamp.valueOf(t) else null, TimestampType),
+ NonFoldableLiteral.create(tz, StringType)),
+ if (expected != null) Timestamp.valueOf(expected) else null)
+ }
+ test("2015-07-24 00:00:00", "PST", "2015-07-23 17:00:00")
+ test("2015-01-24 00:00:00", "PST", "2015-01-23 16:00:00")
+ test(null, "UTC", null)
+ test("2015-07-24 00:00:00", null, null)
+ test(null, null, null)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 3c05e5c3b833c..a41185b4d8754 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -18,11 +18,9 @@
package org.apache.spark.sql.catalyst.expressions
import org.scalactic.TripleEqualsSupport.Spread
-import org.scalatest.Matchers._
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntegralLiteralTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntegralLiteralTestUtils.scala
new file mode 100644
index 0000000000000..2e5a121f4ec56
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntegralLiteralTestUtils.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+/**
+ * Utilities to make sure we pass the proper numeric ranges
+ */
+object IntegralLiteralTestUtils {
+
+ val positiveShort: Short = (Byte.MaxValue + 1).toShort
+ val negativeShort: Short = (Byte.MinValue - 1).toShort
+
+ val positiveShortLit: Literal = Literal(positiveShort)
+ val negativeShortLit: Literal = Literal(negativeShort)
+
+ val positiveInt: Int = Short.MaxValue + 1
+ val negativeInt: Int = Short.MinValue - 1
+
+ val positiveIntLit: Literal = Literal(positiveInt)
+ val negativeIntLit: Literal = Literal(negativeInt)
+
+ val positiveLong: Long = Int.MaxValue + 1L
+ val negativeLong: Long = Int.MinValue - 1L
+
+ val positiveLongLit: Literal = Literal(positiveLong)
+ val negativeLongLit: Literal = Literal(negativeLong)
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
new file mode 100644
index 0000000000000..4addbaf0cbce7
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+
+class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+ val json =
+ """
+ |{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],
+ |"basket":[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]],"book":[{"author":"Nigel Rees",
+ |"title":"Sayings of the Century","category":"reference","price":8.95},
+ |{"author":"Herman Melville","title":"Moby Dick","category":"fiction","price":8.99,
+ |"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":"The Lord of the Rings",
+ |"category":"fiction","reader":[{"age":25,"name":"bob"},{"age":26,"name":"jack"}],
+ |"price":22.99,"isbn":"0-395-19395-8"}],"bicycle":{"price":19.95,"color":"red"}},
+ |"email":"amy@only_for_json_udf_test.net","owner":"amy","zip code":"94025",
+ |"fb:testid":"1234"}
+ |""".stripMargin
+
+ test("$.store.bicycle") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.bicycle")),
+ """{"price":19.95,"color":"red"}""")
+ }
+
+ test("$.store.book") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.book")),
+ """[{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference",
+ |"price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction",
+ |"price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":
+ |"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},
+ |{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}]
+ |""".stripMargin.replace("\n", ""))
+ }
+
+ test("$.store.book[0]") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.book[0]")),
+ """{"author":"Nigel Rees","title":"Sayings of the Century",
+ |"category":"reference","price":8.95}""".stripMargin.replace("\n", ""))
+ }
+
+ test("$.store.book[*]") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.book[*]")),
+ """[{"author":"Nigel Rees","title":"Sayings of the Century","category":"reference",
+ |"price":8.95},{"author":"Herman Melville","title":"Moby Dick","category":"fiction",
+ |"price":8.99,"isbn":"0-553-21311-3"},{"author":"J. R. R. Tolkien","title":
+ |"The Lord of the Rings","category":"fiction","reader":[{"age":25,"name":"bob"},
+ |{"age":26,"name":"jack"}],"price":22.99,"isbn":"0-395-19395-8"}]
+ |""".stripMargin.replace("\n", ""))
+ }
+
+ test("$") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$")),
+ json.replace("\n", ""))
+ }
+
+ test("$.store.book[0].category") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.book[0].category")),
+ "reference")
+ }
+
+ test("$.store.book[*].category") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.book[*].category")),
+ """["reference","fiction","fiction"]""")
+ }
+
+ test("$.store.book[*].isbn") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.book[*].isbn")),
+ """["0-553-21311-3","0-395-19395-8"]""")
+ }
+
+ test("$.store.book[*].reader") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.book[*].reader")),
+ """[{"age":25,"name":"bob"},{"age":26,"name":"jack"}]""")
+ }
+
+ test("$.store.basket[0][1]") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.basket[0][1]")),
+ "2")
+ }
+
+ test("$.store.basket[*]") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.basket[*]")),
+ """[[1,2,{"b":"y","a":"x"}],[3,4],[5,6]]""")
+ }
+
+ test("$.store.basket[*][0]") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.basket[*][0]")),
+ "[1,3,5]")
+ }
+
+ test("$.store.basket[0][*]") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.basket[0][*]")),
+ """[1,2,{"b":"y","a":"x"}]""")
+ }
+
+ test("$.store.basket[*][*]") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.basket[*][*]")),
+ """[1,2,{"b":"y","a":"x"},3,4,5,6]""")
+ }
+
+ test("$.store.basket[0][2].b") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.basket[0][2].b")),
+ "y")
+ }
+
+ test("$.store.basket[0][*].b") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.basket[0][*].b")),
+ """["y"]""")
+ }
+
+ test("$.zip code") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.zip code")),
+ "94025")
+ }
+
+ test("$.fb:testid") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.fb:testid")),
+ "1234")
+ }
+
+ test("preserve newlines") {
+ checkEvaluation(
+ GetJsonObject(Literal("""{"a":"b\nc"}"""), Literal("$.a")),
+ "b\nc")
+ }
+
+ test("escape") {
+ checkEvaluation(
+ GetJsonObject(Literal("""{"a":"b\"c"}"""), Literal("$.a")),
+ "b\"c")
+ }
+
+ test("$.non_exist_key") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.non_exist_key")),
+ null)
+ }
+
+ test("$..no_recursive") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$..no_recursive")),
+ null)
+ }
+
+ test("$.store.book[10]") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.book[10]")),
+ null)
+ }
+
+ test("$.store.book[0].non_exist_key") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.book[0].non_exist_key")),
+ null)
+ }
+
+ test("$.store.basket[*].non_exist_key") {
+ checkEvaluation(
+ GetJsonObject(Literal(json), Literal("$.store.basket[*].non_exist_key")),
+ null)
+ }
+
+ test("non foldable literal") {
+ checkEvaluation(
+ GetJsonObject(NonFoldableLiteral(json), NonFoldableLiteral("$.fb:testid")),
+ "1234")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 9fcb548af6bbb..033792eee6c0f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.types._
class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+ import IntegralLiteralTestUtils._
+
/**
* Used for testing leaf math expressions.
*
@@ -293,6 +295,9 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Bin(l3), java.lang.Long.toBinaryString(123), row)
checkEvaluation(Bin(l4), java.lang.Long.toBinaryString(1234), row)
checkEvaluation(Bin(l5), java.lang.Long.toBinaryString(-123), row)
+
+ checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong))
+ checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong))
}
test("log2") {
@@ -324,6 +329,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong)
checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong)
+
+ checkEvaluation(ShiftLeft(positiveIntLit, positiveIntLit), positiveInt << positiveInt)
+ checkEvaluation(ShiftLeft(positiveIntLit, negativeIntLit), positiveInt << negativeInt)
+ checkEvaluation(ShiftLeft(negativeIntLit, positiveIntLit), negativeInt << positiveInt)
+ checkEvaluation(ShiftLeft(negativeIntLit, negativeIntLit), negativeInt << negativeInt)
+ checkEvaluation(ShiftLeft(positiveLongLit, positiveIntLit), positiveLong << positiveInt)
+ checkEvaluation(ShiftLeft(positiveLongLit, negativeIntLit), positiveLong << negativeInt)
+ checkEvaluation(ShiftLeft(negativeLongLit, positiveIntLit), negativeLong << positiveInt)
+ checkEvaluation(ShiftLeft(negativeLongLit, negativeIntLit), negativeLong << negativeInt)
}
test("shift right") {
@@ -335,6 +349,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong)
checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong)
+
+ checkEvaluation(ShiftRight(positiveIntLit, positiveIntLit), positiveInt >> positiveInt)
+ checkEvaluation(ShiftRight(positiveIntLit, negativeIntLit), positiveInt >> negativeInt)
+ checkEvaluation(ShiftRight(negativeIntLit, positiveIntLit), negativeInt >> positiveInt)
+ checkEvaluation(ShiftRight(negativeIntLit, negativeIntLit), negativeInt >> negativeInt)
+ checkEvaluation(ShiftRight(positiveLongLit, positiveIntLit), positiveLong >> positiveInt)
+ checkEvaluation(ShiftRight(positiveLongLit, negativeIntLit), positiveLong >> negativeInt)
+ checkEvaluation(ShiftRight(negativeLongLit, positiveIntLit), negativeLong >> positiveInt)
+ checkEvaluation(ShiftRight(negativeLongLit, negativeIntLit), negativeLong >> negativeInt)
}
test("shift right unsigned") {
@@ -346,6 +369,23 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong)
checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L)
+
+ checkEvaluation(ShiftRightUnsigned(positiveIntLit, positiveIntLit),
+ positiveInt >>> positiveInt)
+ checkEvaluation(ShiftRightUnsigned(positiveIntLit, negativeIntLit),
+ positiveInt >>> negativeInt)
+ checkEvaluation(ShiftRightUnsigned(negativeIntLit, positiveIntLit),
+ negativeInt >>> positiveInt)
+ checkEvaluation(ShiftRightUnsigned(negativeIntLit, negativeIntLit),
+ negativeInt >>> negativeInt)
+ checkEvaluation(ShiftRightUnsigned(positiveLongLit, positiveIntLit),
+ positiveLong >>> positiveInt)
+ checkEvaluation(ShiftRightUnsigned(positiveLongLit, negativeIntLit),
+ positiveLong >>> negativeInt)
+ checkEvaluation(ShiftRightUnsigned(negativeLongLit, positiveIntLit),
+ negativeLong >>> positiveInt)
+ checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit),
+ negativeLong >>> negativeInt)
}
test("hex") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
index 0559fb80e7fce..31ecf4a9e810a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
@@ -47,4 +47,8 @@ object NonFoldableLiteral {
val lit = Literal(value)
NonFoldableLiteral(lit.value, lit.dataType)
}
+ def create(value: Any, dataType: DataType): NonFoldableLiteral = {
+ val lit = Literal.create(value, dataType)
+ NonFoldableLiteral(lit.value, lit.dataType)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 0bc2812a5dc83..7beef71845e43 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType}
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.types._
class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -118,6 +119,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
+
+ val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
+ LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+ primitiveTypes.map { t =>
+ val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+ val inputData = Seq.fill(10) {
+ val value = dataGen.apply()
+ value match {
+ case d: Double if d.isNaN => 0.0d
+ case f: Float if f.isNaN => 0.0f
+ case _ => value
+ }
+ }
+ val input = inputData.map(Literal(_))
+ checkEvaluation(In(input(0), input.slice(1, 10)),
+ inputData.slice(1, 10).contains(inputData(0)))
+ }
}
test("INSET") {
@@ -134,62 +152,79 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InSet(three, hS), false)
checkEvaluation(InSet(three, nS), false)
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
+
+ val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
+ LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+ primitiveTypes.map { t =>
+ val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+ val inputData = Seq.fill(10) {
+ val value = dataGen.apply()
+ value match {
+ case d: Double if d.isNaN => 0.0d
+ case f: Float if f.isNaN => 0.0f
+ case _ => value
+ }
+ }
+ val input = inputData.map(Literal(_))
+ checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet),
+ inputData.slice(1, 10).contains(inputData(0)))
+ }
}
- private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_))
+ private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
private val largeValues =
- Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))
+ Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_))
private val equalValues1 =
- Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
+ Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
private val equalValues2 =
- Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
+ Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
- test("BinaryComparison: <") {
+ test("BinaryComparison: lessThan") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) < largeValues(i), true)
- checkEvaluation(equalValues1(i) < equalValues2(i), false)
- checkEvaluation(largeValues(i) < smallValues(i), false)
+ checkEvaluation(LessThan(smallValues(i), largeValues(i)), true)
+ checkEvaluation(LessThan(equalValues1(i), equalValues2(i)), false)
+ checkEvaluation(LessThan(largeValues(i), smallValues(i)), false)
}
}
- test("BinaryComparison: <=") {
+ test("BinaryComparison: LessThanOrEqual") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) <= largeValues(i), true)
- checkEvaluation(equalValues1(i) <= equalValues2(i), true)
- checkEvaluation(largeValues(i) <= smallValues(i), false)
+ checkEvaluation(LessThanOrEqual(smallValues(i), largeValues(i)), true)
+ checkEvaluation(LessThanOrEqual(equalValues1(i), equalValues2(i)), true)
+ checkEvaluation(LessThanOrEqual(largeValues(i), smallValues(i)), false)
}
}
- test("BinaryComparison: >") {
+ test("BinaryComparison: GreaterThan") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) > largeValues(i), false)
- checkEvaluation(equalValues1(i) > equalValues2(i), false)
- checkEvaluation(largeValues(i) > smallValues(i), true)
+ checkEvaluation(GreaterThan(smallValues(i), largeValues(i)), false)
+ checkEvaluation(GreaterThan(equalValues1(i), equalValues2(i)), false)
+ checkEvaluation(GreaterThan(largeValues(i), smallValues(i)), true)
}
}
- test("BinaryComparison: >=") {
+ test("BinaryComparison: GreaterThanOrEqual") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) >= largeValues(i), false)
- checkEvaluation(equalValues1(i) >= equalValues2(i), true)
- checkEvaluation(largeValues(i) >= smallValues(i), true)
+ checkEvaluation(GreaterThanOrEqual(smallValues(i), largeValues(i)), false)
+ checkEvaluation(GreaterThanOrEqual(equalValues1(i), equalValues2(i)), true)
+ checkEvaluation(GreaterThanOrEqual(largeValues(i), smallValues(i)), true)
}
}
- test("BinaryComparison: ===") {
+ test("BinaryComparison: EqualTo") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) === largeValues(i), false)
- checkEvaluation(equalValues1(i) === equalValues2(i), true)
- checkEvaluation(largeValues(i) === smallValues(i), false)
+ checkEvaluation(EqualTo(smallValues(i), largeValues(i)), false)
+ checkEvaluation(EqualTo(equalValues1(i), equalValues2(i)), true)
+ checkEvaluation(EqualTo(largeValues(i), smallValues(i)), false)
}
}
- test("BinaryComparison: <=>") {
+ test("BinaryComparison: EqualNullSafe") {
for (i <- 0 until smallValues.length) {
- checkEvaluation(smallValues(i) <=> largeValues(i), false)
- checkEvaluation(equalValues1(i) <=> equalValues2(i), true)
- checkEvaluation(largeValues(i) <=> smallValues(i), false)
+ checkEvaluation(EqualNullSafe(smallValues(i), largeValues(i)), false)
+ checkEvaluation(EqualNullSafe(equalValues1(i), equalValues2(i)), true)
+ checkEvaluation(EqualNullSafe(largeValues(i), smallValues(i)), false)
}
}
@@ -209,8 +244,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
nullTest(GreaterThanOrEqual)
nullTest(EqualTo)
- checkEvaluation(normalInt <=> nullInt, false)
- checkEvaluation(nullInt <=> normalInt, false)
- checkEvaluation(nullInt <=> nullInt, true)
+ checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
+ checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
+ checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 07b952531ec2e..426dc272471ae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -185,12 +185,65 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(s.substr(0), "example", row)
checkEvaluation(s.substring(0, 2), "ex", row)
checkEvaluation(s.substring(0), "example", row)
+
+ val bytes = Array[Byte](1, 2, 3, 4)
+ checkEvaluation(Substring(bytes, 0, 2), Array[Byte](1, 2))
+ checkEvaluation(Substring(bytes, 1, 2), Array[Byte](1, 2))
+ checkEvaluation(Substring(bytes, 2, 2), Array[Byte](2, 3))
+ checkEvaluation(Substring(bytes, 3, 2), Array[Byte](3, 4))
+ checkEvaluation(Substring(bytes, 4, 2), Array[Byte](4))
+ checkEvaluation(Substring(bytes, 8, 2), Array[Byte]())
+ checkEvaluation(Substring(bytes, -1, 2), Array[Byte](4))
+ checkEvaluation(Substring(bytes, -2, 2), Array[Byte](3, 4))
+ checkEvaluation(Substring(bytes, -3, 2), Array[Byte](2, 3))
+ checkEvaluation(Substring(bytes, -4, 2), Array[Byte](1, 2))
+ checkEvaluation(Substring(bytes, -5, 2), Array[Byte](1))
+ checkEvaluation(Substring(bytes, -8, 2), Array[Byte]())
+ }
+
+ test("string substring_index function") {
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(3)), "www.apache.org")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(2)), "www.apache")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(1)), "www")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(0)), "")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-3)), "www.apache.org")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-2)), "apache.org")
+ checkEvaluation(
+ SubstringIndex(Literal("www.apache.org"), Literal("."), Literal(-1)), "org")
+ checkEvaluation(
+ SubstringIndex(Literal(""), Literal("."), Literal(-2)), "")
+ checkEvaluation(
+ SubstringIndex(Literal.create(null, StringType), Literal("."), Literal(-2)), null)
+ checkEvaluation(SubstringIndex(
+ Literal("www.apache.org"), Literal.create(null, StringType), Literal(-2)), null)
+ // non ascii chars
+ // scalastyle:off
+ checkEvaluation(
+ SubstringIndex(Literal("大千世界大千世界"), Literal( "千"), Literal(2)), "大千世界大")
+ // scalastyle:on
+ checkEvaluation(
+ SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache")
}
test("LIKE literal Regular Expression") {
checkEvaluation(Literal.create(null, StringType).like("a"), null)
checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null)
checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null)
+ checkEvaluation(
+ Literal.create("a", StringType).like(NonFoldableLiteral.create("a", StringType)), true)
+ checkEvaluation(
+ Literal.create("a", StringType).like(NonFoldableLiteral.create(null, StringType)), null)
+ checkEvaluation(
+ Literal.create(null, StringType).like(NonFoldableLiteral.create("a", StringType)), null)
+ checkEvaluation(
+ Literal.create(null, StringType).like(NonFoldableLiteral.create(null, StringType)), null)
+
checkEvaluation("abdef" like "abdef", true)
checkEvaluation("a_%b" like "a\\__b", true)
checkEvaluation("addb" like "a_%b", true)
@@ -232,6 +285,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, StringType) rlike "abdef", null)
checkEvaluation("abdef" rlike Literal.create(null, StringType), null)
checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null)
+ checkEvaluation("abdef" rlike NonFoldableLiteral.create("abdef", StringType), true)
+ checkEvaluation("abdef" rlike NonFoldableLiteral.create(null, StringType), null)
+ checkEvaluation(
+ Literal.create(null, StringType) rlike NonFoldableLiteral.create("abdef", StringType), null)
+ checkEvaluation(
+ Literal.create(null, StringType) rlike NonFoldableLiteral.create(null, StringType), null)
+
checkEvaluation("abdef" rlike "abdef", true)
checkEvaluation("abbbbc" rlike "a.*c", true)
@@ -317,6 +377,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))
}
+ test("initcap unit test") {
+ checkEvaluation(InitCap(Literal.create(null, StringType)), null)
+ checkEvaluation(InitCap(Literal("a b")), "A B")
+ checkEvaluation(InitCap(Literal(" a")), " A")
+ checkEvaluation(InitCap(Literal("the test")), "The Test")
+ // scalastyle:off
+ // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+ checkEvaluation(InitCap(Literal("世界")), "世界")
+ // scalastyle:on
+ }
+
+
test("Levenshtein distance") {
checkEvaluation(Levenshtein(Literal.create(null, StringType), Literal("")), null)
checkEvaluation(Levenshtein(Literal(""), Literal.create(null, StringType)), null)
@@ -331,6 +403,48 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}
+ test("soundex unit test") {
+ checkEvaluation(SoundEx(Literal("ZIN")), "Z500")
+ checkEvaluation(SoundEx(Literal("SU")), "S000")
+ checkEvaluation(SoundEx(Literal("")), "")
+ checkEvaluation(SoundEx(Literal.create(null, StringType)), null)
+
+ // scalastyle:off
+ // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+ checkEvaluation(SoundEx(Literal("测试")), "测试")
+ checkEvaluation(SoundEx(Literal("Tschüss")), "T220")
+ // scalastyle:on
+ checkEvaluation(SoundEx(Literal("zZ")), "Z000", create_row("s8"))
+ checkEvaluation(SoundEx(Literal("RAGSSEEESSSVEEWE")), "R221")
+ checkEvaluation(SoundEx(Literal("Ashcraft")), "A261")
+ checkEvaluation(SoundEx(Literal("Aswcraft")), "A261")
+ checkEvaluation(SoundEx(Literal("Tymczak")), "T522")
+ checkEvaluation(SoundEx(Literal("Pfister")), "P236")
+ checkEvaluation(SoundEx(Literal("Miller")), "M460")
+ checkEvaluation(SoundEx(Literal("Peterson")), "P362")
+ checkEvaluation(SoundEx(Literal("Peters")), "P362")
+ checkEvaluation(SoundEx(Literal("Auerbach")), "A612")
+ checkEvaluation(SoundEx(Literal("Uhrbach")), "U612")
+ checkEvaluation(SoundEx(Literal("Moskowitz")), "M232")
+ checkEvaluation(SoundEx(Literal("Moskovitz")), "M213")
+ checkEvaluation(SoundEx(Literal("relyheewsgeessg")), "R422")
+ checkEvaluation(SoundEx(Literal("!!")), "!!")
+ }
+
+ test("translate") {
+ checkEvaluation(
+ StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae")
+ checkEvaluation(StringTranslate(Literal("translate"), Literal(""), Literal("123")), "translate")
+ checkEvaluation(StringTranslate(Literal("translate"), Literal("rnlt"), Literal("")), "asae")
+ // test for multiple mapping
+ checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("123")), "12cd")
+ checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("12")), "12cd")
+ // scalastyle:off
+ // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+ checkEvaluation(StringTranslate(Literal("花花世界"), Literal("花界"), Literal("ab")), "aa世b")
+ // scalastyle:on
+ }
+
test("TRIM/LTRIM/RTRIM") {
val s = 'a.string.at(0)
checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef "))
@@ -575,4 +689,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null)
}
+
+ test("find in set") {
+ checkEvaluation(
+ FindInSet(Literal.create(null, StringType), Literal.create(null, StringType)), null)
+ checkEvaluation(FindInSet(Literal("ab"), Literal.create(null, StringType)), null)
+ checkEvaluation(FindInSet(Literal.create(null, StringType), Literal("abc,b,ab,c,def")), null)
+ checkEvaluation(FindInSet(Literal("ab"), Literal("abc,b,ab,c,def")), 3)
+ checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0)
+ checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
deleted file mode 100644
index 6a907290f2dbe..0000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import scala.collection.JavaConverters._
-import scala.util.Random
-
-import org.scalatest.{BeforeAndAfterEach, Matchers}
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
-import org.apache.spark.unsafe.types.UTF8String
-
-
-class UnsafeFixedWidthAggregationMapSuite
- extends SparkFunSuite
- with Matchers
- with BeforeAndAfterEach {
-
- import UnsafeFixedWidthAggregationMap._
-
- private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
- private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
- private def emptyAggregationBuffer: InternalRow = InternalRow(0)
- private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
-
- private var memoryManager: TaskMemoryManager = null
-
- override def beforeEach(): Unit = {
- memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
- }
-
- override def afterEach(): Unit = {
- if (memoryManager != null) {
- memoryManager.cleanUpAllAllocatedMemory()
- memoryManager = null
- }
- }
-
- test("supported schemas") {
- assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
- assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil)))
-
- assert(
- !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
- assert(
- !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
- }
-
- test("empty map") {
- val map = new UnsafeFixedWidthAggregationMap(
- emptyAggregationBuffer,
- aggBufferSchema,
- groupKeySchema,
- memoryManager,
- 1024, // initial capacity,
- PAGE_SIZE_BYTES,
- false // disable perf metrics
- )
- assert(!map.iterator().hasNext)
- map.free()
- }
-
- test("updating values for a single key") {
- val map = new UnsafeFixedWidthAggregationMap(
- emptyAggregationBuffer,
- aggBufferSchema,
- groupKeySchema,
- memoryManager,
- 1024, // initial capacity
- PAGE_SIZE_BYTES,
- false // disable perf metrics
- )
- val groupKey = InternalRow(UTF8String.fromString("cats"))
-
- // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
- map.getAggregationBuffer(groupKey)
- val iter = map.iterator()
- val entry = iter.next()
- assert(!iter.hasNext)
- entry.key.getString(0) should be ("cats")
- entry.value.getInt(0) should be (0)
-
- // Modifications to rows retrieved from the map should update the values in the map
- entry.value.setInt(0, 42)
- map.getAggregationBuffer(groupKey).getInt(0) should be (42)
-
- map.free()
- }
-
- test("inserting large random keys") {
- val map = new UnsafeFixedWidthAggregationMap(
- emptyAggregationBuffer,
- aggBufferSchema,
- groupKeySchema,
- memoryManager,
- 128, // initial capacity
- PAGE_SIZE_BYTES,
- false // disable perf metrics
- )
- val rand = new Random(42)
- val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet
- groupKeys.foreach { keyString =>
- map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
- }
- val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
- entry.key.getString(0)
- }.toSet
- seenKeys.size should be (groupKeys.size)
- seenKeys should be (groupKeys)
-
- map.free()
- }
-
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index b7bc17f89e82f..8c72203193630 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -31,6 +31,8 @@ import org.apache.spark.unsafe.types.UTF8String
class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
+ private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size)
+
test("basic conversion with only primitive types") {
val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
val converter = UnsafeProjection.create(fieldTypes)
@@ -46,7 +48,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow.getLong(1) === 1)
assert(unsafeRow.getInt(2) === 2)
- // We can copy UnsafeRows as long as they don't reference ObjectPools
val unsafeRowCopy = unsafeRow.copy()
assert(unsafeRowCopy.getLong(0) === 0)
assert(unsafeRowCopy.getLong(1) === 1)
@@ -74,8 +75,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val unsafeRow: UnsafeRow = converter.apply(row)
assert(unsafeRow.getSizeInBytes === 8 + (8 * 3) +
- ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
- ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
+ roundedSize("Hello".getBytes.length) +
+ roundedSize("World".getBytes.length))
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
@@ -88,13 +89,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
- row.setString(1, "Hello")
+ row.update(1, UTF8String.fromString("Hello"))
row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01")))
row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))
val unsafeRow: UnsafeRow = converter.apply(row)
- assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) +
- ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
+ assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) + roundedSize("Hello".getBytes.length))
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
@@ -122,8 +122,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
FloatType,
DoubleType,
StringType,
- BinaryType
- // DecimalType.Default,
+ BinaryType,
+ DecimalType.USER_DEFAULT,
+ DecimalType.SYSTEM_DEFAULT
// ArrayType(IntegerType)
)
val converter = UnsafeProjection.create(fieldTypes)
@@ -150,7 +151,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(createdFromNull.getDouble(7) === 0.0d)
assert(createdFromNull.getUTF8String(8) === null)
assert(createdFromNull.getBinary(9) === null)
- // assert(createdFromNull.get(10) === null)
+ assert(createdFromNull.getDecimal(10, 10, 0) === null)
+ assert(createdFromNull.getDecimal(11, 38, 18) === null)
// assert(createdFromNull.get(11) === null)
// If we have an UnsafeRow with columns that are initially non-null and we null out those
@@ -168,11 +170,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
r.setDouble(7, 700)
r.update(8, UTF8String.fromString("hello"))
r.update(9, "world".getBytes)
- // r.update(10, Decimal(10))
+ r.setDecimal(10, Decimal(10), 10)
+ r.setDecimal(11, Decimal(10.00, 38, 18), 38)
// r.update(11, Array(11))
r
}
+ // todo: we reuse the UnsafeRow in projection, so these tests are meaningless.
val setToNullAfterCreation = converter.apply(rowWithNoNullColumns)
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
@@ -184,11 +188,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9))
- // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
+ assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
+ rowWithNoNullColumns.getDecimal(10, 10, 0))
+ assert(setToNullAfterCreation.getDecimal(11, 38, 18) ===
+ rowWithNoNullColumns.getDecimal(11, 38, 18))
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
for (i <- fieldTypes.indices) {
- setToNullAfterCreation.setNullAt(i)
+      // Can't call setNullAt() on DecimalType
+ if (i == 11) {
+ setToNullAfterCreation.setDecimal(11, null, 38)
+ } else {
+ setToNullAfterCreation.setNullAt(i)
+ }
}
// There are some garbage left in the var-length area
assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes()))
@@ -203,7 +215,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
setToNullAfterCreation.setDouble(7, 700)
// setToNullAfterCreation.update(8, UTF8String.fromString("hello"))
// setToNullAfterCreation.update(9, "world".getBytes)
- // setToNullAfterCreation.update(10, Decimal(10))
+ setToNullAfterCreation.setDecimal(10, Decimal(10), 10)
+ setToNullAfterCreation.setDecimal(11, Decimal(10.00, 38, 18), 38)
// setToNullAfterCreation.update(11, Array(11))
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
@@ -216,7 +229,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
// assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
// assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9))
- // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
+ assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
+ rowWithNoNullColumns.getDecimal(10, 10, 0))
+ assert(setToNullAfterCreation.getDecimal(11, 38, 18) ===
+ rowWithNoNullColumns.getDecimal(11, 38, 18))
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}
@@ -234,4 +250,108 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val converter = UnsafeProjection.create(fieldTypes)
assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
}
+
+ test("basic conversion with array type") {
+ val fieldTypes: Array[DataType] = Array(
+ ArrayType(LongType),
+ ArrayType(ArrayType(LongType))
+ )
+ val converter = UnsafeProjection.create(fieldTypes)
+
+ val array1 = new GenericArrayData(Array[Any](1L, 2L))
+ val array2 = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](3L, 4L))))
+ val row = new GenericMutableRow(fieldTypes.length)
+ row.update(0, array1)
+ row.update(1, array2)
+
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(unsafeRow.numFields() == 2)
+
+ val unsafeArray1 = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]
+ assert(unsafeArray1.getSizeInBytes == 4 * 2 + 8 * 2)
+ assert(unsafeArray1.numElements() == 2)
+ assert(unsafeArray1.getLong(0) == 1L)
+ assert(unsafeArray1.getLong(1) == 2L)
+
+ val unsafeArray2 = unsafeRow.getArray(1).asInstanceOf[UnsafeArrayData]
+ assert(unsafeArray2.numElements() == 1)
+
+ val nestedArray = unsafeArray2.getArray(0).asInstanceOf[UnsafeArrayData]
+ assert(nestedArray.getSizeInBytes == 4 * 2 + 8 * 2)
+ assert(nestedArray.numElements() == 2)
+ assert(nestedArray.getLong(0) == 3L)
+ assert(nestedArray.getLong(1) == 4L)
+
+ assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
+
+ val array1Size = roundedSize(4 + unsafeArray1.getSizeInBytes)
+ val array2Size = roundedSize(4 + unsafeArray2.getSizeInBytes)
+ assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
+ }
+
+ test("basic conversion with map type") {
+ def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
+
+ def testIntLongMap(map: UnsafeMapData, keys: Array[Int], values: Array[Long]): Unit = {
+ val numElements = keys.length
+ assert(map.numElements() == numElements)
+
+ val keyArray = map.keys
+ assert(keyArray.getSizeInBytes == 4 * numElements + 4 * numElements)
+ assert(keyArray.numElements() == numElements)
+ keys.zipWithIndex.foreach { case (key, i) =>
+ assert(keyArray.getInt(i) == key)
+ }
+
+ val valueArray = map.values
+ assert(valueArray.getSizeInBytes == 4 * numElements + 8 * numElements)
+ assert(valueArray.numElements() == numElements)
+ values.zipWithIndex.foreach { case (value, i) =>
+ assert(valueArray.getLong(i) == value)
+ }
+
+ assert(map.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ }
+
+ val fieldTypes: Array[DataType] = Array(
+ MapType(IntegerType, LongType),
+ MapType(IntegerType, MapType(IntegerType, LongType))
+ )
+ val converter = UnsafeProjection.create(fieldTypes)
+
+ val map1 = new ArrayBasedMapData(createArray(1, 2), createArray(3L, 4L))
+
+ val innerMap = new ArrayBasedMapData(createArray(5, 6), createArray(7L, 8L))
+ val map2 = new ArrayBasedMapData(createArray(9), createArray(innerMap))
+
+ val row = new GenericMutableRow(fieldTypes.length)
+ row.update(0, map1)
+ row.update(1, map2)
+
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(unsafeRow.numFields() == 2)
+
+ val unsafeMap1 = unsafeRow.getMap(0).asInstanceOf[UnsafeMapData]
+ testIntLongMap(unsafeMap1, Array(1, 2), Array(3L, 4L))
+
+ val unsafeMap2 = unsafeRow.getMap(1).asInstanceOf[UnsafeMapData]
+ assert(unsafeMap2.numElements() == 1)
+
+ val keyArray = unsafeMap2.keys
+ assert(keyArray.getSizeInBytes == 4 + 4)
+ assert(keyArray.numElements() == 1)
+ assert(keyArray.getInt(0) == 9)
+
+ val valueArray = unsafeMap2.values
+ assert(valueArray.numElements() == 1)
+ val nestedMap = valueArray.getMap(0).asInstanceOf[UnsafeMapData]
+ testIntLongMap(nestedMap, Array(5, 6), Array(7L, 8L))
+ assert(valueArray.getSizeInBytes == 4 + 8 + nestedMap.getSizeInBytes)
+
+ assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+ val map1Size = roundedSize(8 + unsafeMap1.getSizeInBytes)
+ val map2Size = roundedSize(8 + unsafeMap2.getSizeInBytes)
+ assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
new file mode 100644
index 0000000000000..796d60032e1a6
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+
+/**
+ * A test suite for the bitset portion of the row concatenation.
+ */
+class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
+
+ test("bitset concat: boundary size 0, 0") {
+ testBitsets(0, 0)
+ }
+
+ test("bitset concat: boundary size 0, 64") {
+ testBitsets(0, 64)
+ }
+
+ test("bitset concat: boundary size 64, 0") {
+ testBitsets(64, 0)
+ }
+
+ test("bitset concat: boundary size 64, 64") {
+ testBitsets(64, 64)
+ }
+
+ test("bitset concat: boundary size 0, 128") {
+ testBitsets(0, 128)
+ }
+
+ test("bitset concat: boundary size 128, 0") {
+ testBitsets(128, 0)
+ }
+
+ test("bitset concat: boundary size 128, 128") {
+ testBitsets(128, 128)
+ }
+
+ test("bitset concat: single word bitsets") {
+ testBitsets(10, 5)
+ }
+
+ test("bitset concat: first bitset larger than a word") {
+ testBitsets(67, 5)
+ }
+
+ test("bitset concat: second bitset larger than a word") {
+ testBitsets(6, 67)
+ }
+
+ test("bitset concat: no reduction in bitset size") {
+ testBitsets(33, 34)
+ }
+
+ test("bitset concat: two words") {
+ testBitsets(120, 95)
+ }
+
+ test("bitset concat: bitset 65, 128") {
+ testBitsets(65, 128)
+ }
+
+ test("bitset concat: randomized tests") {
+ for (i <- 1 until 20) {
+ val numFields1 = Random.nextInt(1000)
+ val numFields2 = Random.nextInt(1000)
+ testBitsetsOnce(numFields1, numFields2)
+ }
+ }
+
+ private def createUnsafeRow(numFields: Int): UnsafeRow = {
+ val row = new UnsafeRow
+ val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
+ // Allocate a larger buffer than needed and point the UnsafeRow to somewhere in the middle.
+ // This way we can test the joiner when the input UnsafeRows are not the entire arrays.
+ val offset = numFields * 8
+ val buf = new Array[Byte](sizeInBytes + offset)
+ row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes)
+ row
+ }
+
+ private def testBitsets(numFields1: Int, numFields2: Int): Unit = {
+ for (i <- 0 until 5) {
+ testBitsetsOnce(numFields1, numFields2)
+ }
+ }
+
+ private def testBitsetsOnce(numFields1: Int, numFields2: Int): Unit = {
+ info(s"num fields: $numFields1 and $numFields2")
+ val schema1 = StructType(Seq.tabulate(numFields1) { i => StructField(s"a_$i", IntegerType) })
+ val schema2 = StructType(Seq.tabulate(numFields2) { i => StructField(s"b_$i", IntegerType) })
+
+ val row1 = createUnsafeRow(numFields1)
+ val row2 = createUnsafeRow(numFields2)
+
+ if (numFields1 > 0) {
+ for (i <- 0 until Random.nextInt(numFields1)) {
+ row1.setNullAt(Random.nextInt(numFields1))
+ }
+ }
+ if (numFields2 > 0) {
+ for (i <- 0 until Random.nextInt(numFields2)) {
+ row2.setNullAt(Random.nextInt(numFields2))
+ }
+ }
+
+ val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
+ val output = concater.join(row1, row2)
+
+ def dumpDebug(): String = {
+ val set1 = Seq.tabulate(numFields1) { i => if (row1.isNullAt(i)) "1" else "0" }
+ val set2 = Seq.tabulate(numFields2) { i => if (row2.isNullAt(i)) "1" else "0" }
+ val out = Seq.tabulate(numFields1 + numFields2) { i => if (output.isNullAt(i)) "1" else "0" }
+
+ s"""
+ |input1: ${set1.mkString}
+ |input2: ${set2.mkString}
+ |output: ${out.mkString}
+ |expect: ${set1.mkString}${set2.mkString}
+ """.stripMargin
+ }
+
+ for (i <- 0 until (numFields1 + numFields2)) {
+ if (i < numFields1) {
+ assert(output.isNullAt(i) === row1.isNullAt(i), dumpDebug())
+ } else {
+ assert(output.isNullAt(i) === row2.isNullAt(i - numFields1), dumpDebug())
+ }
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
new file mode 100644
index 0000000000000..59729e7646beb
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.types._
+
+/**
+ * Test suite for [[GenerateUnsafeRowJoiner]].
+ *
+ * There is also a separate [[GenerateUnsafeRowJoinerBitsetSuite]] that tests specifically
+ * concatenation for the bitset portion, since that is the hardest one to get right.
+ */
+class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
+
+ private val fixed = Seq(IntegerType)
+ private val variable = Seq(IntegerType, StringType)
+
+ test("simple fixed width types") {
+ testConcat(0, 0, fixed)
+ testConcat(0, 1, fixed)
+ testConcat(1, 0, fixed)
+ testConcat(64, 0, fixed)
+ testConcat(0, 64, fixed)
+ testConcat(64, 64, fixed)
+ }
+
+ test("randomized fix width types") {
+ for (i <- 0 until 20) {
+ testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed)
+ }
+ }
+
+ test("simple variable width types") {
+ testConcat(0, 0, variable)
+ testConcat(0, 1, variable)
+ testConcat(1, 0, variable)
+ testConcat(64, 0, variable)
+ testConcat(0, 64, variable)
+ testConcat(64, 64, variable)
+ }
+
+ test("randomized variable width types") {
+ for (i <- 0 until 10) {
+ testConcatOnce(Random.nextInt(100), Random.nextInt(100), variable)
+ }
+ }
+
+ private def testConcat(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]): Unit = {
+ for (i <- 0 until 10) {
+ testConcatOnce(numFields1, numFields2, candidateTypes)
+ }
+ }
+
+ private def testConcatOnce(numFields1: Int, numFields2: Int, candidateTypes: Seq[DataType]) {
+ info(s"schema size $numFields1, $numFields2")
+ val schema1 = RandomDataGenerator.randomSchema(numFields1, candidateTypes)
+ val schema2 = RandomDataGenerator.randomSchema(numFields2, candidateTypes)
+
+ // Create the converters needed to convert from external row to internal row and to UnsafeRows.
+ val internalConverter1 = CatalystTypeConverters.createToCatalystConverter(schema1)
+ val internalConverter2 = CatalystTypeConverters.createToCatalystConverter(schema2)
+ val converter1 = UnsafeProjection.create(schema1)
+ val converter2 = UnsafeProjection.create(schema2)
+
+ // Create the input rows, convert them into UnsafeRows.
+ val extRow1 = RandomDataGenerator.forType(schema1, nullable = false).get.apply()
+ val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply()
+ val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
+ val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])
+
+ // Run the joiner.
+ val mergedSchema = StructType(schema1 ++ schema2)
+ val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
+ val output = concater.join(row1, row2)
+
+ // Test everything equals ...
+ for (i <- mergedSchema.indices) {
+ if (i < schema1.size) {
+ assert(output.isNullAt(i) === row1.isNullAt(i))
+ if (!output.isNullAt(i)) {
+ assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType))
+ }
+ } else {
+ assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size))
+ if (!output.isNullAt(i)) {
+ assert(output.get(i, mergedSchema(i).dataType) ===
+ row2.get(i - schema1.size, mergedSchema(i).dataType))
+ }
+ }
+ }
+ }
+
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
new file mode 100644
index 0000000000000..8c7ee8720f7bb
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A test suite for generated projections
+ */
+class GeneratedProjectionSuite extends SparkFunSuite {
+
+ test("generated projections on wider table") {
+ val N = 1000
+ val wideRow1 = new GenericInternalRow((1 to N).toArray[Any])
+ val schema1 = StructType((1 to N).map(i => StructField("", IntegerType)))
+ val wideRow2 = new GenericInternalRow(
+ (1 to N).map(i => UTF8String.fromString(i.toString)).toArray[Any])
+ val schema2 = StructType((1 to N).map(i => StructField("", StringType)))
+ val joined = new JoinedRow(wideRow1, wideRow2)
+ val joinedSchema = StructType(schema1 ++ schema2)
+ val nested = new JoinedRow(InternalRow(joined, joined), joined)
+ val nestedSchema = StructType(
+ Seq(StructField("", joinedSchema), StructField("", joinedSchema)) ++ joinedSchema)
+
+ // test generated UnsafeProjection
+ val unsafeProj = UnsafeProjection.create(nestedSchema)
+ val unsafe: UnsafeRow = unsafeProj(nested)
+ (0 until N).foreach { i =>
+ val s = UTF8String.fromString((i + 1).toString)
+ assert(i + 1 === unsafe.getInt(i + 2))
+ assert(s === unsafe.getUTF8String(i + 2 + N))
+ assert(i + 1 === unsafe.getStruct(0, N * 2).getInt(i))
+ assert(s === unsafe.getStruct(0, N * 2).getUTF8String(i + N))
+ assert(i + 1 === unsafe.getStruct(1, N * 2).getInt(i))
+ assert(s === unsafe.getStruct(1, N * 2).getUTF8String(i + N))
+ }
+
+ // test generated SafeProjection
+ val safeProj = FromUnsafeProjection(nestedSchema)
+ val result = safeProj(unsafe)
+ // Can't compare GenericInternalRow with JoinedRow directly
+ (0 until N).foreach { i =>
+ val r = i + 1
+ val s = UTF8String.fromString((i + 1).toString)
+ assert(r === result.getInt(i + 2))
+ assert(s === result.getUTF8String(i + 2 + N))
+ assert(r === result.getStruct(0, N * 2).getInt(i))
+ assert(s === result.getStruct(0, N * 2).getUTF8String(i + N))
+ assert(r === result.getStruct(1, N * 2).getInt(i))
+ assert(s === result.getStruct(1, N * 2).getUTF8String(i + N))
+ }
+
+ // test generated MutableProjection
+ val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) =>
+ BoundReference(i, f.dataType, true)
+ }
+ val mutableProj = GenerateMutableProjection.generate(exprs)()
+ val row1 = mutableProj(result)
+ assert(result === row1)
+ val row2 = mutableProj(result)
+ assert(result === row2)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index d4916ea8d273a..1877cff1334bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.analysis.{AnalysisSuite, EliminateSubQueries}
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -88,20 +89,24 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
('a === 'b || 'b > 3 && 'a > 3 && 'a < 5))
}
- private def caseInsensitiveAnalyse(plan: LogicalPlan) =
- AnalysisSuite.caseInsensitiveAnalyzer.execute(plan)
+ private val caseInsensitiveAnalyzer =
+ new Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(false))
test("(a && b) || (a && c) => a && (b || c) when case insensitive") {
- val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5)))
+ val plan = caseInsensitiveAnalyzer.execute(
+ testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5)))
val actual = Optimize.execute(plan)
- val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 && ('b > 3 || 'b < 5)))
+ val expected = caseInsensitiveAnalyzer.execute(
+ testRelation.where('a > 2 && ('b > 3 || 'b < 5)))
comparePlans(actual, expected)
}
test("(a || b) && (a || c) => a || (b && c) when case insensitive") {
- val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5)))
+ val plan = caseInsensitiveAnalyzer.execute(
+ testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5)))
val actual = Optimize.execute(plan)
- val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 || ('b > 3 && 'b < 5)))
+ val expected = caseInsensitiveAnalyzer.execute(
+ testRelation.where('a > 2 || ('b > 3 && 'b < 5)))
comparePlans(actual, expected)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 1d433275fed2e..6f7b5b9572e22 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
- test("OptimizedIn test: In clause optimized to InSet") {
+ test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
val originalQuery =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
.analyze
+ val optimized = Optimize.execute(originalQuery.analyze)
+ comparePlans(optimized, originalQuery)
+ }
+
+ test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
+ val originalQuery =
+ testRelation
+ .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))
+ .analyze
+
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
- .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
+ .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
.analyze
comparePlans(optimized, correctAnswer)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
new file mode 100644
index 0000000000000..455a3810c719e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util._
+
+/**
+ * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly
+ * skip sub-trees that have already been marked as analyzed.
+ */
+class LogicalPlanSuite extends SparkFunSuite {
+ private var invocationCount = 0
+ private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
+ case p: Project =>
+ invocationCount += 1
+ p
+ }
+
+ private val testRelation = LocalRelation()
+
+ test("resolveOperator runs on operators") {
+ invocationCount = 0
+ val plan = Project(Nil, testRelation)
+ plan resolveOperators function
+
+ assert(invocationCount === 1)
+ }
+
+ test("resolveOperator runs on operators recursively") {
+ invocationCount = 0
+ val plan = Project(Nil, Project(Nil, testRelation))
+ plan resolveOperators function
+
+ assert(invocationCount === 2)
+ }
+
+ test("resolveOperator skips all ready resolved plans") {
+ invocationCount = 0
+ val plan = Project(Nil, Project(Nil, testRelation))
+ plan.foreach(_.setAnalyzed())
+ plan resolveOperators function
+
+ assert(invocationCount === 0)
+ }
+
+ test("resolveOperator skips partially resolved plans") {
+ invocationCount = 0
+ val plan1 = Project(Nil, testRelation)
+ val plan2 = Project(Nil, plan1)
+ plan1.foreach(_.setAnalyzed())
+ plan2 resolveOperators function
+
+ assert(invocationCount === 1)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
index 60d2bcfe13757..d18fa4df13355 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala
@@ -398,4 +398,26 @@ class DateTimeUtilsSuite extends SparkFunSuite {
c2.set(1996, 2, 31, 0, 0, 0)
assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11)
}
+
+ test("from UTC timestamp") {
+ def test(utc: String, tz: String, expected: String): Unit = {
+ assert(toJavaTimestamp(fromUTCTime(fromJavaTimestamp(Timestamp.valueOf(utc)), tz)).toString
+ === expected)
+ }
+ test("2011-12-25 09:00:00.123456", "UTC", "2011-12-25 09:00:00.123456")
+ test("2011-12-25 09:00:00.123456", "JST", "2011-12-25 18:00:00.123456")
+ test("2011-12-25 09:00:00.123456", "PST", "2011-12-25 01:00:00.123456")
+ test("2011-12-25 09:00:00.123456", "Asia/Shanghai", "2011-12-25 17:00:00.123456")
+ }
+
+ test("to UTC timestamp") {
+ def test(utc: String, tz: String, expected: String): Unit = {
+ assert(toJavaTimestamp(toUTCTime(fromJavaTimestamp(Timestamp.valueOf(utc)), tz)).toString
+ === expected)
+ }
+ test("2011-12-25 09:00:00.123456", "UTC", "2011-12-25 09:00:00.123456")
+ test("2011-12-25 18:00:00.123456", "JST", "2011-12-25 09:00:00.123456")
+ test("2011-12-25 01:00:00.123456", "PST", "2011-12-25 09:00:00.123456")
+ test("2011-12-25 17:00:00.123456", "Asia/Shanghai", "2011-12-25 09:00:00.123456")
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
new file mode 100644
index 0000000000000..d6f273f9e568a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.StringUtils._
+
+class StringUtilsSuite extends SparkFunSuite {
+
+ test("escapeLikeRegex") {
+ assert(escapeLikeRegex("abdef") === "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E")
+ assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E_.\\Qb\\E")
+ assert(escapeLikeRegex("a_%b") === "(?s)\\Qa\\E..*\\Qb\\E")
+ assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*%\\Qb\\E")
+ assert(escapeLikeRegex("a%") === "(?s)\\Qa\\E.*")
+ assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E")
+ assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 88b221cd81d74..706ecd29d1355 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -170,6 +170,30 @@ class DataTypeSuite extends SparkFunSuite {
}
}
+ test("existsRecursively") {
+ val struct = StructType(
+ StructField("a", LongType) ::
+ StructField("b", FloatType) :: Nil)
+ assert(struct.existsRecursively(_.isInstanceOf[LongType]))
+ assert(struct.existsRecursively(_.isInstanceOf[StructType]))
+ assert(!struct.existsRecursively(_.isInstanceOf[IntegerType]))
+
+ val mapType = MapType(struct, StringType)
+ assert(mapType.existsRecursively(_.isInstanceOf[LongType]))
+ assert(mapType.existsRecursively(_.isInstanceOf[StructType]))
+ assert(mapType.existsRecursively(_.isInstanceOf[StringType]))
+ assert(mapType.existsRecursively(_.isInstanceOf[MapType]))
+ assert(!mapType.existsRecursively(_.isInstanceOf[IntegerType]))
+
+ val arrayType = ArrayType(mapType)
+ assert(arrayType.existsRecursively(_.isInstanceOf[LongType]))
+ assert(arrayType.existsRecursively(_.isInstanceOf[StructType]))
+ assert(arrayType.existsRecursively(_.isInstanceOf[StringType]))
+ assert(arrayType.existsRecursively(_.isInstanceOf[MapType]))
+ assert(arrayType.existsRecursively(_.isInstanceOf[ArrayType]))
+ assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType]))
+ }
+
def checkDataTypeJsonRepr(dataType: DataType): Unit = {
test(s"JSON - $dataType") {
assert(DataType.fromJson(dataType.json) === dataType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
index 0ee9ddac815b8..417df006ab7c2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
@@ -34,8 +34,9 @@ object DataTypeTestUtils {
* decimal types.
*/
val fractionalTypes: Set[FractionalType] = Set(
+ DecimalType.USER_DEFAULT,
+ DecimalType(20, 5),
DecimalType.SYSTEM_DEFAULT,
- DecimalType(2, 1),
DoubleType,
FloatType
)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
index 1d297beb3868d..6921d15958a55 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
@@ -166,6 +166,27 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(Decimal(100) % Decimal(0) === null)
}
+ // regression test for SPARK-8359
+ test("accurate precision after multiplication") {
+ val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal
+ assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249")
+ }
+
+ // regression test for SPARK-8677
+ test("fix non-terminating decimal expansion problem") {
+ val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3)
+ // The difference between decimal should not be more than 0.001.
+ assert(decimal.toDouble - 0.333 < 0.001)
+ }
+
+ // regression test for SPARK-8800
+ test("fix loss of precision/scale when doing division operation") {
+ val a = Decimal(2) / Decimal(3)
+ assert(a.toDouble < 1.0 && a.toDouble > 0.6)
+ val b = Decimal(1) / Decimal(8)
+ assert(b.toDouble === 0.125)
+ }
+
test("set/setOrNull") {
assert(new Decimal().set(10L, 10, 0).toUnscaledLong === 10L)
assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index be0966641b5c4..349007789f634 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -106,6 +106,11 @@
parquet-avrotest
+
+ org.mockito
+ mockito-core
+ test
+ target/scala-${scala.binary.version}/classes
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
similarity index 60%
rename from sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
rename to sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 03f4c3ed8e6bb..09511ff35f785 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -15,14 +15,21 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.expressions;
+package org.apache.spark.sql.execution;
-import java.util.Iterator;
+import java.io.IOException;
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
-import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.KVIterator;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
@@ -61,26 +68,13 @@ public final class UnsafeFixedWidthAggregationMap {
private final boolean enablePerfMetrics;
- /**
- * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
- * false otherwise.
- */
- public static boolean supportsGroupKeySchema(StructType schema) {
- for (StructField field: schema.fields()) {
- if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
- return false;
- }
- }
- return true;
- }
-
/**
* @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
* schema, false otherwise.
*/
public static boolean supportsAggregationBufferSchema(StructType schema) {
for (StructField field: schema.fields()) {
- if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+ if (!UnsafeRow.isMutable(field.dataType())) {
return false;
}
}
@@ -93,7 +87,9 @@ public static boolean supportsAggregationBufferSchema(StructType schema) {
* @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
* @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
* @param groupingKeySchema the schema of the grouping key, used for row conversion.
- * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
+ * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+ * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with
+ * other tasks.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
* @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
@@ -102,31 +98,35 @@ public UnsafeFixedWidthAggregationMap(
InternalRow emptyAggregationBuffer,
StructType aggregationBufferSchema,
StructType groupingKeySchema,
- TaskMemoryManager memoryManager,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes,
boolean enablePerfMetrics) {
this.aggregationBufferSchema = aggregationBufferSchema;
this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
this.groupingKeySchema = groupingKeySchema;
- this.map =
- new BytesToBytesMap(memoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+ this.map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
this.enablePerfMetrics = enablePerfMetrics;
// Initialize the buffer for aggregation value
final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
- assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 +
- UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length()));
}
/**
* Return the aggregation buffer for the current group. For efficiency, all calls to this method
- * return the same object.
+ * return the same object. If additional memory could not be allocated, then this method will
+ * signal an error by returning null.
*/
public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
+ return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
+ }
+
+ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) {
// Probe our map using the serialized key
final BytesToBytesMap.Location loc = map.lookup(
unsafeGroupingKeyRow.getBaseObject(),
@@ -135,14 +135,17 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
- loc.putNewKey(
+ boolean putSucceeded = loc.putNewKey(
unsafeGroupingKeyRow.getBaseObject(),
unsafeGroupingKeyRow.getBaseOffset(),
unsafeGroupingKeyRow.getSizeInBytes(),
emptyAggregationBuffer,
- PlatformDependent.BYTE_ARRAY_OFFSET,
+ Platform.BYTE_ARRAY_OFFSET,
emptyAggregationBuffer.length
);
+ if (!putSucceeded) {
+ return null;
+ }
}
// Reset the pointer to point to the value that we just stored or looked up:
@@ -157,59 +160,75 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
}
/**
- * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}.
- */
- public static class MapEntry {
- private MapEntry() { };
- public final UnsafeRow key = new UnsafeRow();
- public final UnsafeRow value = new UnsafeRow();
- }
-
- /**
- * Returns an iterator over the keys and values in this map.
+ * Returns an iterator over the keys and values in this map. This uses destructive iterator of
+ * BytesToBytesMap. So it is illegal to call any other method on this map after `iterator()` has
+ * been called.
*
* For efficiency, each call returns the same object.
*/
- public Iterator iterator() {
- return new Iterator() {
+ public KVIterator iterator() {
+ return new KVIterator() {
+
+ private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator =
+ map.destructiveIterator();
+ private final UnsafeRow key = new UnsafeRow();
+ private final UnsafeRow value = new UnsafeRow();
- private final MapEntry entry = new MapEntry();
- private final Iterator mapLocationIterator = map.iterator();
+ @Override
+ public boolean next() {
+ if (mapLocationIterator.hasNext()) {
+ final BytesToBytesMap.Location loc = mapLocationIterator.next();
+ final MemoryLocation keyAddress = loc.getKeyAddress();
+ final MemoryLocation valueAddress = loc.getValueAddress();
+ key.pointTo(
+ keyAddress.getBaseObject(),
+ keyAddress.getBaseOffset(),
+ groupingKeySchema.length(),
+ loc.getKeyLength()
+ );
+ value.pointTo(
+ valueAddress.getBaseObject(),
+ valueAddress.getBaseOffset(),
+ aggregationBufferSchema.length(),
+ loc.getValueLength()
+ );
+ return true;
+ } else {
+ return false;
+ }
+ }
@Override
- public boolean hasNext() {
- return mapLocationIterator.hasNext();
+ public UnsafeRow getKey() {
+ return key;
}
@Override
- public MapEntry next() {
- final BytesToBytesMap.Location loc = mapLocationIterator.next();
- final MemoryLocation keyAddress = loc.getKeyAddress();
- final MemoryLocation valueAddress = loc.getValueAddress();
- entry.key.pointTo(
- keyAddress.getBaseObject(),
- keyAddress.getBaseOffset(),
- groupingKeySchema.length(),
- loc.getKeyLength()
- );
- entry.value.pointTo(
- valueAddress.getBaseObject(),
- valueAddress.getBaseOffset(),
- aggregationBufferSchema.length(),
- loc.getValueLength()
- );
- return entry;
+ public UnsafeRow getValue() {
+ return value;
}
@Override
- public void remove() {
- throw new UnsupportedOperationException();
+ public void close() {
+ // Do nothing.
}
};
}
/**
- * Free the unsafe memory associated with this map.
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ return map.getPeakMemoryUsedBytes();
+ }
+
+ @VisibleForTesting
+ public int getNumDataPages() {
+ return map.getNumDataPages();
+ }
+
+ /**
+ * Free the memory associated with this map. This is idempotent and can be called multiple times.
*/
public void free() {
map.free();
@@ -226,4 +245,18 @@ public void printPerfMetrics() {
System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
}
+ /**
+ * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]]
+ * that can be used to insert more records to do external sorting.
+ *
+ * The only memory that is allocated is the address/prefix array, 16 bytes per record.
+ *
+ * Note that this destroys the map, and as a result, the map cannot be used anymore after this.
+ */
+ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
+ UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter(
+ groupingKeySchema, aggregationBufferSchema,
+ SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map);
+ return sorter;
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
new file mode 100644
index 0000000000000..7db6b7ff50f22
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+
+import javax.annotation.Nullable;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.KVIterator;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.sort.*;
+
+/**
+ * A class for performing external sorting on key-value records. Both key and value are UnsafeRows.
+ *
+ * Note that this class allows optionally passing in a {@link BytesToBytesMap} directly in order
+ * to perform in-place sorting of records in the map.
+ */
+public final class UnsafeKVExternalSorter {
+
+ private final StructType keySchema;
+ private final StructType valueSchema;
+ private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
+ private final UnsafeExternalSorter sorter;
+
+ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
+ BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes)
+ throws IOException {
+ this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null);
+ }
+
+ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
+ BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes,
+ @Nullable BytesToBytesMap map) throws IOException {
+ this.keySchema = keySchema;
+ this.valueSchema = valueSchema;
+ final TaskContext taskContext = TaskContext.get();
+
+ prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema);
+ PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
+ BaseOrdering ordering = GenerateOrdering.create(keySchema);
+ KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
+
+ TaskMemoryManager taskMemoryManager = taskContext.taskMemoryManager();
+
+ if (map == null) {
+ sorter = UnsafeExternalSorter.create(
+ taskMemoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ /* initialSize */ 4096,
+ pageSizeBytes);
+ } else {
+ // Insert the records into the in-memory sorter.
+ // We will use the number of elements in the map as the initialSize of the
+ // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize,
+ // we will use 1 as its initial size if the map is empty.
+ final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
+ taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements()));
+
+ // We cannot use the destructive iterator here because we are reusing the existing memory
+ // pages in BytesToBytesMap to hold records during sorting.
+ // The only new memory we are allocating is the pointer/prefix array.
+ BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
+ final int numKeyFields = keySchema.size();
+ UnsafeRow row = new UnsafeRow();
+ while (iter.hasNext()) {
+ final BytesToBytesMap.Location loc = iter.next();
+ final Object baseObject = loc.getKeyAddress().getBaseObject();
+ final long baseOffset = loc.getKeyAddress().getBaseOffset();
+
+ // Get encoded memory address
+        // baseObject + baseOffset point to the beginning of the key data in the map, but
+        // the KV-pair's length data is stored in the word immediately before that address
+ MemoryBlock page = loc.getMemoryPage();
+ long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
+
+ // Compute prefix
+ row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
+ final long prefix = prefixComputer.computePrefix(row);
+
+ inMemSorter.insertRecord(address, prefix);
+ }
+
+ sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
+ taskContext.taskMemoryManager(),
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ new KVComparator(ordering, keySchema.length()),
+ prefixComparator,
+ /* initialSize */ 4096,
+ pageSizeBytes,
+ inMemSorter);
+
+ sorter.spill();
+ map.free();
+ }
+ }
+
+ /**
+ * Inserts a key-value record into the sorter. If the sorter no longer has enough memory to hold
+ * the record, the sorter sorts the existing records in-memory, writes them out as partially
+ * sorted runs, and then reallocates memory to hold the new record.
+ */
+ public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException {
+ final long prefix = prefixComputer.computePrefix(key);
+ sorter.insertKVRecord(
+ key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
+ value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
+ }
+
+ /**
+ * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
+ * after consuming this iterator.
+ */
+ public KVSorterIterator sortedIterator() throws IOException {
+ try {
+ final UnsafeSorterIterator underlying = sorter.getSortedIterator();
+ if (!underlying.hasNext()) {
+ // Since we won't ever call next() on an empty iterator, we need to clean up resources
+ // here in order to prevent memory leaks.
+ cleanupResources();
+ }
+ return new KVSorterIterator(underlying);
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+ /**
+ * Return the peak memory used so far, in bytes.
+ */
+ public long getPeakMemoryUsedBytes() {
+ return sorter.getPeakMemoryUsedBytes();
+ }
+
+ /**
+   * Marks the current page as no-more-space-available, and as a result, we will either
+   * allocate a new page or spill when we see the next record.
+ */
+ @VisibleForTesting
+ void closeCurrentPage() {
+ sorter.closeCurrentPage();
+ }
+
+ /**
+ * Frees this sorter's in-memory data structures and cleans up its spill files.
+ */
+ public void cleanupResources() {
+ sorter.cleanupResources();
+ }
+
+ private static final class KVComparator extends RecordComparator {
+ private final BaseOrdering ordering;
+ private final UnsafeRow row1 = new UnsafeRow();
+ private final UnsafeRow row2 = new UnsafeRow();
+ private final int numKeyFields;
+
+ public KVComparator(BaseOrdering ordering, int numKeyFields) {
+ this.numKeyFields = numKeyFields;
+ this.ordering = ordering;
+ }
+
+ @Override
+ public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+ // Note that since ordering doesn't need the total length of the record, we just pass -1
+ // into the row.
+ row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
+ row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
+ return ordering.compare(row1, row2);
+ }
+ }
+
+ public class KVSorterIterator extends KVIterator {
+ private UnsafeRow key = new UnsafeRow();
+ private UnsafeRow value = new UnsafeRow();
+ private final int numKeyFields = keySchema.size();
+ private final int numValueFields = valueSchema.size();
+ private final UnsafeSorterIterator underlying;
+
+ private KVSorterIterator(UnsafeSorterIterator underlying) {
+ this.underlying = underlying;
+ }
+
+ @Override
+ public boolean next() throws IOException {
+ try {
+ if (underlying.hasNext()) {
+ underlying.loadNext();
+
+ Object baseObj = underlying.getBaseObject();
+ long recordOffset = underlying.getBaseOffset();
+ int recordLen = underlying.getRecordLength();
+
+ // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
+ int keyLen = Platform.getInt(baseObj, recordOffset);
+ int valueLen = recordLen - keyLen - 4;
+ key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
+ value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
+
+ return true;
+ } else {
+ key = null;
+ value = null;
+ cleanupResources();
+ return false;
+ }
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+ @Override
+ public UnsafeRow getKey() {
+ return key;
+ }
+
+ @Override
+ public UnsafeRow getValue() {
+ return value;
+ }
+
+ @Override
+ public void close() {
+ cleanupResources();
+ }
+ };
+}
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
new file mode 100644
index 0000000000000..ca50000b4756e
--- /dev/null
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -0,0 +1,3 @@
+org.apache.spark.sql.execution.datasources.jdbc.DefaultSource
+org.apache.spark.sql.execution.datasources.json.DefaultSource
+org.apache.spark.sql.execution.datasources.parquet.DefaultSource
diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css
new file mode 100644
index 0000000000000..ddd3a91dd8ef8
--- /dev/null
+++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#plan-viz-graph .label {
+ font-weight: normal;
+ text-shadow: none;
+}
+
+#plan-viz-graph svg g.node rect {
+ fill: #C3EBFF;
+ stroke: #3EC0FF;
+ stroke-width: 1px;
+}
+
+/* Highlight the SparkPlan node name */
+#plan-viz-graph svg text :first-child {
+ font-weight: bold;
+}
+
+#plan-viz-graph svg path {
+ stroke: #444;
+ stroke-width: 1.5px;
+}
diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js
new file mode 100644
index 0000000000000..5161fcde669e7
--- /dev/null
+++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.js
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+var PlanVizConstants = {
+ svgMarginX: 16,
+ svgMarginY: 16
+};
+
+function renderPlanViz() {
+ var svg = planVizContainer().append("svg");
+ var metadata = d3.select("#plan-viz-metadata");
+ var dot = metadata.select(".dot-file").text().trim();
+ var graph = svg.append("g");
+
+ var g = graphlibDot.read(dot);
+ preprocessGraphLayout(g);
+ var renderer = new dagreD3.render();
+ renderer(graph, g);
+
+ // Round corners on rectangles
+ svg
+ .selectAll("rect")
+ .attr("rx", "5")
+ .attr("ry", "5");
+
+ var nodeSize = parseInt($("#plan-viz-metadata-size").text());
+ for (var i = 0; i < nodeSize; i++) {
+ setupTooltipForSparkPlanNode(i);
+ }
+
+ resizeSvg(svg)
+}
+
+/* -------------------- *
+ * | Helper functions | *
+ * -------------------- */
+
+function planVizContainer() { return d3.select("#plan-viz-graph"); }
+
+/*
+ * Set up the tooltip for a SparkPlan node using metadata. When the user moves the mouse over
+ * the node, it will display the details of this SparkPlan node on the right.
+ */
+function setupTooltipForSparkPlanNode(nodeId) {
+ var nodeTooltip = d3.select("#plan-meta-data-" + nodeId).text()
+ d3.select("svg g .node_" + nodeId)
+ .on('mouseover', function(d) {
+ var domNode = d3.select(this).node();
+ $(domNode).tooltip({
+ title: nodeTooltip, trigger: "manual", container: "body", placement: "right"
+ });
+ $(domNode).tooltip("show");
+ })
+ .on('mouseout', function(d) {
+ var domNode = d3.select(this).node();
+ $(domNode).tooltip("destroy");
+ })
+}
+
+/*
+ * Helper function to pre-process the graph layout.
+ * This step is necessary for certain styles that affect the positioning
+ * and sizes of graph elements, e.g. padding, font style, shape.
+ */
+function preprocessGraphLayout(g) {
+ var nodes = g.nodes();
+ for (var i = 0; i < nodes.length; i++) {
+ var node = g.node(nodes[i]);
+ node.padding = "5";
+ }
+ // Curve the edges
+ var edges = g.edges();
+ for (var j = 0; j < edges.length; j++) {
+ var edge = g.edge(edges[j]);
+ edge.lineInterpolate = "basis";
+ }
+}
+
+/*
+ * Helper function to size the SVG appropriately such that all elements are displayed.
+ * This assumes that all outermost elements are clusters (rectangles).
+ */
+function resizeSvg(svg) {
+ var allClusters = svg.selectAll("g rect")[0];
+ console.log(allClusters);
+ var startX = -PlanVizConstants.svgMarginX +
+ toFloat(d3.min(allClusters, function(e) {
+ console.log(e);
+ return getAbsolutePosition(d3.select(e)).x;
+ }));
+ var startY = -PlanVizConstants.svgMarginY +
+ toFloat(d3.min(allClusters, function(e) {
+ return getAbsolutePosition(d3.select(e)).y;
+ }));
+ var endX = PlanVizConstants.svgMarginX +
+ toFloat(d3.max(allClusters, function(e) {
+ var t = d3.select(e);
+ return getAbsolutePosition(t).x + toFloat(t.attr("width"));
+ }));
+ var endY = PlanVizConstants.svgMarginY +
+ toFloat(d3.max(allClusters, function(e) {
+ var t = d3.select(e);
+ return getAbsolutePosition(t).y + toFloat(t.attr("height"));
+ }));
+ var width = endX - startX;
+ var height = endY - startY;
+ svg.attr("viewBox", startX + " " + startY + " " + width + " " + height)
+ .attr("width", width)
+ .attr("height", height);
+}
+
+/* Helper function to convert attributes to numeric values. */
+function toFloat(f) {
+ if (f) {
+ return parseFloat(f.toString().replace(/px$/, ""));
+ } else {
+ return f;
+ }
+}
+
+/*
+ * Helper function to compute the absolute position of the specified element in our graph.
+ */
+function getAbsolutePosition(d3selection) {
+ if (d3selection.empty()) {
+ throw "Attempted to get absolute position of an empty selection.";
+ }
+ var obj = d3selection;
+ var _x = toFloat(obj.attr("x")) || 0;
+ var _y = toFloat(obj.attr("y")) || 0;
+ while (!obj.empty()) {
+ var transformText = obj.attr("transform");
+ if (transformText) {
+ var translate = d3.transform(transformText).translate;
+ _x += toFloat(translate[0]);
+ _y += toFloat(translate[1]);
+ }
+ // Climb upwards to find how our parents are translated
+ obj = d3.select(obj.node().parentNode);
+ // Stop when we've reached the graph container itself
+ if (obj.node() == planVizContainer().node()) {
+ break;
+ }
+ }
+ return { x: _x, y: _y };
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b25dcbca82b9f..27bd084847346 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -54,7 +54,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
def this(name: String) = this(name match {
case "*" => UnresolvedStar(None)
case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
- case _ => UnresolvedAttribute(name)
+ case _ => UnresolvedAttribute.quotedString(name)
})
/** Creates a column based on the given expression. */
@@ -627,8 +627,19 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @group expr_ops
* @since 1.3.0
*/
+ @deprecated("use isin", "1.5.0")
@scala.annotation.varargs
- def in(list: Any*): Column = In(expr, list.map(lit(_).expr))
+ def in(list: Any*): Column = isin(list : _*)
+
+ /**
+ * A boolean expression that is evaluated to true if the value of this expression is contained
+ * by the evaluated values of the arguments.
+ *
+ * @group expr_ops
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def isin(list: Any*): Column = In(expr, list.map(lit(_).expr))
/**
* SQL like expression.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 3ea0f9ed3bddd..c466d9e6cb349 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
import java.util.Properties
-import org.apache.spark.unsafe.types.UTF8String
-
import scala.language.implicitConversions
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
@@ -36,12 +34,12 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
+import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, SQLExecution}
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
-import org.apache.spark.sql.json.{JacksonGenerator, JSONRelation}
+import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.sources.HadoopFsRelation
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -54,7 +52,6 @@ private[sql] object DataFrame {
}
}
-
/**
* :: Experimental ::
* A distributed collection of data organized into named columns.
@@ -119,6 +116,9 @@ class DataFrame private[sql](
@transient val sqlContext: SQLContext,
@DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) extends Serializable {
+ // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure
+  // you wrap it with `withNewExecutionId` if this action doesn't call another action.
+
/**
* A constructor that automatically analyzes the logical plan.
*
@@ -168,7 +168,7 @@ class DataFrame private[sql](
}
/**
- * Internal API for Python
+ * Compose the string representing rows for output
* @param _numRows Number of rows to show
* @param truncate Whether truncate long strings and align cells right
*/
@@ -687,9 +687,7 @@ class DataFrame private[sql](
case Column(explode: Explode) => MultiAlias(explode, Nil)
case Column(expr: Expression) => Alias(expr, expr.prettyString)()
}
- // When user continuously call `select`, speed up analysis by collapsing `Project`
- import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing
- Project(namedExpressions.toSeq, ProjectCollapsing(logicalPlan))
+ Project(namedExpressions.toSeq, logicalPlan)
}
/**
@@ -1356,14 +1354,18 @@ class DataFrame private[sql](
* @group rdd
* @since 1.3.0
*/
- def foreach(f: Row => Unit): Unit = rdd.foreach(f)
+ def foreach(f: Row => Unit): Unit = withNewExecutionId {
+ rdd.foreach(f)
+ }
/**
* Applies a function f to each partition of this [[DataFrame]].
* @group rdd
* @since 1.3.0
*/
- def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
+ def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId {
+ rdd.foreachPartition(f)
+ }
/**
* Returns the first `n` rows in the [[DataFrame]].
@@ -1377,14 +1379,18 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
+ def collect(): Array[Row] = withNewExecutionId {
+ queryExecution.executedPlan.executeCollect()
+ }
/**
* Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
* @group action
* @since 1.3.0
*/
- def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*)
+ def collectAsList(): java.util.List[Row] = withNewExecutionId {
+ java.util.Arrays.asList(rdd.collect() : _*)
+ }
/**
* Returns the number of rows in the [[DataFrame]].
@@ -1554,10 +1560,10 @@ class DataFrame private[sql](
*/
def inputFiles: Array[String] = {
val files: Seq[String] = logicalPlan.collect {
- case LogicalRelation(fsBasedRelation: HadoopFsRelation) =>
- fsBasedRelation.paths.toSeq
- case LogicalRelation(jsonRelation: JSONRelation) =>
- jsonRelation.path.toSeq
+ case LogicalRelation(fsBasedRelation: FileRelation) =>
+ fsBasedRelation.inputFiles
+ case fr: FileRelation =>
+ fr.inputFiles
}.flatten
files.toSet.toArray
}
@@ -1643,8 +1649,12 @@ class DataFrame private[sql](
* an RDD out to a parquet file, and then register that file as a table. This "table" can then
* be the target of an `insertInto`.
*
- * Also note that while this function can persist the table metadata into Hive's metastore,
- * the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
+ * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input
+ * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC
+ * and Parquet), the table is persisted in a Hive compatible format, which means other systems
+ * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL
+ * specific format.
+ *
* @group output
* @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`.
*/
@@ -1662,8 +1672,12 @@ class DataFrame private[sql](
* an RDD out to a parquet file, and then register that file as a table. This "table" can then
* be the target of an `insertInto`.
*
- * Also note that while this function can persist the table metadata into Hive's metastore,
- * the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
+ * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input
+ * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC
+ * and Parquet), the table is persisted in a Hive compatible format, which means other systems
+ * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL
+ * specific format.
+ *
* @group output
* @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`.
*/
@@ -1682,8 +1696,12 @@ class DataFrame private[sql](
* an RDD out to a parquet file, and then register that file as a table. This "table" can then
* be the target of an `insertInto`.
*
- * Also note that while this function can persist the table metadata into Hive's metastore,
- * the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
+ * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input
+ * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC
+ * and Parquet), the table is persisted in a Hive compatible format, which means other systems
+ * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL
+ * specific format.
+ *
* @group output
* @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`.
*/
@@ -1702,8 +1720,12 @@ class DataFrame private[sql](
* an RDD out to a parquet file, and then register that file as a table. This "table" can then
* be the target of an `insertInto`.
*
- * Also note that while this function can persist the table metadata into Hive's metastore,
- * the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
+ * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input
+ * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC
+ * and Parquet), the table is persisted in a Hive compatible format, which means other systems
+ * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL
+ * specific format.
+ *
* @group output
* @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`.
*/
@@ -1721,8 +1743,12 @@ class DataFrame private[sql](
* an RDD out to a parquet file, and then register that file as a table. This "table" can then
* be the target of an `insertInto`.
*
- * Also note that while this function can persist the table metadata into Hive's metastore,
- * the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
+ * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input
+ * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC
+ * and Parquet), the table is persisted in a Hive compatible format, which means other systems
+ * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL
+ * specific format.
+ *
* @group output
* @deprecated As of 1.4.0, replaced by
* `write().format(source).mode(mode).options(options).saveAsTable(tableName)`.
@@ -1747,8 +1773,12 @@ class DataFrame private[sql](
* an RDD out to a parquet file, and then register that file as a table. This "table" can then
* be the target of an `insertInto`.
*
- * Also note that while this function can persist the table metadata into Hive's metastore,
- * the table will NOT be accessible from Hive, until SPARK-7550 is resolved.
+ * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input
+ * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC
+ * and Parquet), the table is persisted in a Hive compatible format, which means other systems
+ * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL
+ * specific format.
+ *
* @group output
* @deprecated As of 1.4.0, replaced by
* `write().format(source).mode(mode).options(options).saveAsTable(tableName)`.
@@ -1863,6 +1893,14 @@ class DataFrame private[sql](
write.mode(SaveMode.Append).insertInto(tableName)
}
+ /**
+ * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with
+ * an execution.
+ */
+ private[sql] def withNewExecutionId[T](body: => T): T = {
+ SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body)
+ }
+
////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
// End of deprecated methods
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index eb09807f9d9c2..9ea955b010017 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -25,10 +25,10 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
+import org.apache.spark.sql.execution.datasources.json.JSONRelation
+import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
-import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
-import org.apache.spark.sql.json.JSONRelation
-import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.sql.types.StructType
import org.apache.spark.{Logging, Partition}
@@ -237,7 +237,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
def json(jsonRDD: RDD[String]): DataFrame = {
val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble
sqlContext.baseRelationToDataFrame(
- new JSONRelation(() => jsonRDD, None, samplingRatio, userSpecifiedSchema)(sqlContext))
+ new JSONRelation(Some(jsonRDD), samplingRatio, userSpecifiedSchema, None, None)(sqlContext))
}
/**
@@ -260,7 +260,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
sqlContext.baseRelationToDataFrame(
new ParquetRelation(
- globbedPaths.map(_.toString), None, None, extraOptions.toMap)(sqlContext))
+ globbedPaths.map(_.toString), userSpecifiedSchema, None, extraOptions.toMap)(sqlContext))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 4ec58082e7aef..2e68e358f2f1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql
+import java.{util => ju, lang => jl}
+
+import scala.collection.JavaConverters._
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat._
@@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def freqItems(cols: Seq[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction given on each stratum.
+ * @param col column that defines strata
+ * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+ * its fraction as zero.
+ * @param seed random seed
+ * @tparam T stratum type
+ * @return a new [[DataFrame]] that represents the stratified sample
+ *
+ * @since 1.5.0
+ */
+ def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
+ require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+ s"Fractions must be in [0, 1], but got $fractions.")
+ import org.apache.spark.sql.functions.{rand, udf}
+ val c = Column(col)
+ val r = rand(seed)
+ val f = udf { (stratum: Any, x: Double) =>
+ x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
+ }
+ df.filter(f(c, r))
+ }
+
+ /**
+ * Returns a stratified sample without replacement based on the fraction given on each stratum.
+ * @param col column that defines strata
+ * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+ * its fraction as zero.
+ * @param seed random seed
+ * @tparam T stratum type
+ * @return a new [[DataFrame]] that represents the stratified sample
+ *
+ * @since 1.5.0
+ */
+ def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+ sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 7e3318cefe62c..5fa11da4c38cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -23,8 +23,9 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource}
-import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils}
+import org.apache.spark.sql.sources.HadoopFsRelation
/**
@@ -185,6 +186,12 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* When `mode` is `Append`, the schema of the [[DataFrame]] need to be
* the same as that of the existing table, and format or options will be ignored.
*
+ * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input
+ * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC
+ * and Parquet), the table is persisted in a Hive compatible format, which means other systems
+ * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL
+ * specific format.
+ *
* @since 1.4.0
*/
def saveAsTable(tableName: String): Unit = {
@@ -257,7 +264,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
// Create the table if the table didn't exist.
if (!tableExists) {
- val schema = JDBCWriteDetails.schemaString(df, url)
+ val schema = JdbcUtils.schemaString(df, url)
val sql = s"CREATE TABLE $table ($schema)"
conn.prepareStatement(sql).executeUpdate()
}
@@ -265,7 +272,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
conn.close()
}
- JDBCWriteDetails.saveTable(df, url, table, connectionProperties)
+ JdbcUtils.saveTable(df, url, table, connectionProperties)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 6644e85d4a037..e9de14f025502 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -200,7 +200,7 @@ private[spark] object SQLConf {
val IN_MEMORY_PARTITION_PRUNING =
booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning",
- defaultValue = Some(false),
+ defaultValue = Some(true),
doc = "When true, enable partition pruning for in-memory columnar tables.",
isPublic = false)
@@ -223,14 +223,21 @@ private[spark] object SQLConf {
defaultValue = Some(200),
doc = "The default number of partitions to use when shuffling data for joins or aggregations.")
- val CODEGEN_ENABLED = booleanConf("spark.sql.codegen",
+ val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled",
defaultValue = Some(true),
+ doc = "When true, use the optimized Tungsten physical execution backend which explicitly " +
+ "manages memory and dynamically generates bytecode for expression evaluation.")
+
+ val CODEGEN_ENABLED = booleanConf("spark.sql.codegen",
+ defaultValue = Some(true), // use TUNGSTEN_ENABLED as default
doc = "When true, code will be dynamically generated at runtime for expression evaluation in" +
- " a specific query.")
+ " a specific query.",
+ isPublic = false)
val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled",
- defaultValue = Some(true),
- doc = "When true, use the new optimized Tungsten physical execution backend.")
+ defaultValue = Some(true), // use TUNGSTEN_ENABLED as default
+ doc = "When true, use the new optimized Tungsten physical execution backend.",
+ isPublic = false)
val DIALECT = stringConf(
"spark.sql.dialect",
@@ -359,17 +366,21 @@ private[spark] object SQLConf {
"storing additional schema information in Hive's metastore.",
isPublic = false)
- // Whether to perform partition discovery when loading external data sources. Default to true.
val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled",
defaultValue = Some(true),
doc = "When true, automtically discover data partitions.")
- // Whether to perform partition column type inference. Default to true.
val PARTITION_COLUMN_TYPE_INFERENCE =
booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled",
defaultValue = Some(true),
doc = "When true, automatically infer the data types for partitioned columns.")
+ val PARTITION_MAX_FILES =
+ intConf("spark.sql.sources.maxConcurrentWrites",
+ defaultValue = Some(5),
+      doc = "The maximum number of concurrent files to open before falling back on sorting when " +
+ "writing out files using dynamic partitioning.")
+
// The output committer class used by HadoopFsRelation. The specified class needs to be a
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
//
@@ -409,10 +420,6 @@ private[spark] object SQLConf {
val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
defaultValue = Some(true), doc = "")
- val USE_SQL_SERIALIZER2 = booleanConf(
- "spark.sql.useSerializer2",
- defaultValue = Some(true), isPublic = false)
-
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -427,7 +434,6 @@ private[spark] object SQLConf {
*
* SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads).
*/
-
private[sql] class SQLConf extends Serializable with CatalystConf {
import SQLConf._
@@ -474,16 +480,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN)
- private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED)
+ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED))
def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)
- private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED)
+ private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED))
private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
- private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
-
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
private[spark] def defaultSizeInBytes: Long =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dbb2a09846548..4bf00b3399e7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
+import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -74,6 +75,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
protected[sql] def conf = currentSession().conf
+ // `listener` should be only used in the driver
+ @transient private[sql] val listener = new SQLListener(this)
+ sparkContext.addSparkListener(listener)
+ sparkContext.ui.foreach(new SQLTab(this, _))
+
/**
* Set Spark SQL configuration properties.
*
@@ -285,9 +291,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
val udf: UDFRegistration = new UDFRegistration(this)
- @transient
- val udaf: UDAFRegistration = new UDAFRegistration(this)
-
/**
* Returns true if the table is currently cached in-memory.
* @group cachemgmt
@@ -340,7 +343,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
implicit class StringToColumn(val sc: StringContext) {
def $(args: Any*): ColumnName = {
- new ColumnName(sc.s(args : _*))
+ new ColumnName(sc.s(args: _*))
}
}
@@ -870,7 +873,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
HashAggregation ::
Aggregation ::
LeftSemiJoin ::
- HashJoin ::
+ EquiJoinSelection ::
InMemoryScans ::
BasicOperators ::
CartesianProduct ::
@@ -1008,9 +1011,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
def output =
analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ")
- // TODO previously will output RDD details by run (${stringOrError(toRdd.toDebugString)})
- // however, the `toRdd` will cause the real execution, which is not what we want.
- // We need to think about how to avoid the side effect.
s"""== Parsed Logical Plan ==
|${stringOrError(logical)}
|== Analyzed Logical Plan ==
@@ -1021,7 +1021,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
|== Physical Plan ==
|${stringOrError(executedPlan)}
|Code Generation: ${stringOrError(executedPlan.codegenEnabled)}
- |== RDD ==
""".stripMargin.trim
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 7cd7421a518c9..1f270560d7bc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -26,6 +26,8 @@ import org.apache.spark.Logging
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types.DataType
/**
@@ -52,6 +54,20 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
functionRegistry.registerFunction(name, udf.builder)
}
+ /**
+ * Register a user-defined aggregate function (UDAF).
+ * @param name the name of the UDAF.
+ * @param udaf the UDAF needs to be registered.
+ * @return the registered UDAF.
+ */
+ def register(
+ name: String,
+ udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
+ def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
+ functionRegistry.registerFunction(name, builder)
+ udaf
+ }
+
// scalastyle:off
/* register 0-22 were generated by this script
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 454b7b91a63f5..1620fc401ba6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -114,7 +114,7 @@ private[sql] class FixedDecimalColumnBuilder(
precision: Int,
scale: Int)
extends NativeColumnBuilder(
- new FixedDecimalColumnStats,
+ new FixedDecimalColumnStats(precision, scale),
FIXED_DECIMAL(precision, scale))
// TODO (lian) Add support for array, struct and map
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 32a84b2676e07..5cbd52bc0590e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.columnar
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -66,7 +66,7 @@ private[sql] sealed trait ColumnStats extends Serializable {
* Column statistics represented as a single row, currently including closed lower bound, closed
* upper bound and null count.
*/
- def collectedStatistics: InternalRow
+ def collectedStatistics: GenericInternalRow
}
/**
@@ -75,7 +75,8 @@ private[sql] sealed trait ColumnStats extends Serializable {
private[sql] class NoopColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal)
- override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, 0L)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L))
}
private[sql] class BooleanColumnStats extends ColumnStats {
@@ -92,8 +93,8 @@ private[sql] class BooleanColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class ByteColumnStats extends ColumnStats {
@@ -110,8 +111,8 @@ private[sql] class ByteColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class ShortColumnStats extends ColumnStats {
@@ -128,8 +129,8 @@ private[sql] class ShortColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class IntColumnStats extends ColumnStats {
@@ -146,8 +147,8 @@ private[sql] class IntColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class LongColumnStats extends ColumnStats {
@@ -164,8 +165,8 @@ private[sql] class LongColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class FloatColumnStats extends ColumnStats {
@@ -182,8 +183,8 @@ private[sql] class FloatColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class DoubleColumnStats extends ColumnStats {
@@ -200,8 +201,8 @@ private[sql] class DoubleColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class StringColumnStats extends ColumnStats {
@@ -218,8 +219,8 @@ private[sql] class StringColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class BinaryColumnStats extends ColumnStats {
@@ -230,26 +231,26 @@ private[sql] class BinaryColumnStats extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(null, null, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
}
-private[sql] class FixedDecimalColumnStats extends ColumnStats {
+private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
protected var upper: Decimal = null
protected var lower: Decimal = null
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
- val value = row.getDecimal(ordinal)
+ val value = row.getDecimal(ordinal, precision, scale)
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
sizeInBytes += FIXED_DECIMAL.defaultSize
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(lower, upper, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
@@ -262,8 +263,8 @@ private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
}
}
- override def collectedStatistics: InternalRow =
- InternalRow(null, null, nullCount, count, sizeInBytes)
+ override def collectedStatistics: GenericInternalRow =
+ new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
}
private[sql] class DateColumnStats extends IntColumnStats
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 2863f6c230a9d..531a8244d55d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -329,7 +329,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
}
override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
- row.update(ordinal, value)
+ row.update(ordinal, value.clone())
}
override def getField(row: InternalRow, ordinal: Int): UTF8String = {
@@ -337,7 +337,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
}
override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
- to.update(toOrdinal, from.getUTF8String(fromOrdinal))
+ setField(to, toOrdinal, getField(from, fromOrdinal))
}
}
@@ -392,11 +392,15 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
}
override def getField(row: InternalRow, ordinal: Int): Decimal = {
- row.getDecimal(ordinal)
+ row.getDecimal(ordinal, precision, scale)
}
override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
- row(ordinal) = value
+ row.setDecimal(ordinal, value, precision)
+ }
+
+ override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+ setField(to, toOrdinal, getField(from, fromOrdinal))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 5d5b0697d7016..d553bb6169ecc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -148,7 +148,7 @@ private[sql] case class InMemoryRelation(
}
val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
- .flatMap(_.toSeq))
+ .flatMap(_.values))
batchStats += stats
CachedBatch(columnBuilders.map(_.build().array()), stats)
@@ -330,10 +330,11 @@ private[sql] case class InMemoryColumnarTableScan(
if (inMemoryPartitionPruningEnabled) {
cachedBatchIterator.filter { cachedBatch =>
if (!partitionFilter(cachedBatch.stats)) {
- def statsString: String = relation.partitionStatistics.schema
- .zip(cachedBatch.stats.toSeq)
- .map { case (a, s) => s"${a.name}: $s" }
- .mkString(", ")
+ def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map {
+ case (a, i) =>
+ val value = cachedBatch.stats.get(i, a.dataType)
+ s"${a.name}: $value"
+ }.mkString(", ")
logInfo(s"Skipping partition based on stats $statsString")
false
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index e8c6a0f8f801d..f3b6a3a5f4a33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* :: DeveloperApi ::
@@ -45,6 +46,10 @@ case class Aggregate(
child: SparkPlan)
extends UnaryNode {
+ override private[sql] lazy val metrics = Map(
+ "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
override def requiredChildDistribution: List[Distribution] = {
if (partial) {
UnspecifiedDistribution :: Nil
@@ -121,12 +126,15 @@ case class Aggregate(
}
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ val numInputRows = longMetric("numInputRows")
+ val numOutputRows = longMetric("numOutputRows")
if (groupingExpressions.isEmpty) {
child.execute().mapPartitions { iter =>
val buffer = newAggregateBuffer()
var currentRow: InternalRow = null
while (iter.hasNext) {
currentRow = iter.next()
+ numInputRows += 1
var i = 0
while (i < buffer.length) {
buffer(i).update(currentRow)
@@ -142,6 +150,7 @@ case class Aggregate(
i += 1
}
+ numOutputRows += 1
Iterator(resultProjection(aggregateResults))
}
} else {
@@ -152,6 +161,7 @@ case class Aggregate(
var currentRow: InternalRow = null
while (iter.hasNext) {
currentRow = iter.next()
+ numInputRows += 1
val currentGroup = groupingProjection(currentRow)
var currentBuffer = hashTable.get(currentGroup)
if (currentBuffer == null) {
@@ -180,6 +190,7 @@ case class Aggregate(
val currentEntry = hashTableIter.next()
val currentGroup = currentEntry.getKey
val currentBuffer = currentEntry.getValue
+ numOutputRows += 1
var i = 0
while (i < currentBuffer.length) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 70e5031fb63c0..029f2264a6a27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair
@@ -39,20 +40,26 @@ import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEn
@DeveloperApi
case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
- override def outputPartitioning: Partitioning = newPartitioning
+ override def nodeName: String = if (tungstenMode) "TungstenExchange" else "Exchange"
- override def output: Seq[Attribute] = child.output
+ /**
+ * Returns true iff we can support the data type, and we are not doing range partitioning.
+ */
+ private lazy val tungstenMode: Boolean = {
+ unsafeEnabled && codegenEnabled && GenerateUnsafeProjection.canSupport(child.schema) &&
+ !newPartitioning.isInstanceOf[RangePartitioning]
+ }
- override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+ override def outputPartitioning: Partitioning = newPartitioning
- override def canProcessSafeRows: Boolean = true
+ override def output: Seq[Attribute] = child.output
- override def canProcessUnsafeRows: Boolean = {
- // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to
- // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to
- // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed.
- !newPartitioning.isInstanceOf[RangePartitioning]
- }
+ // This setting is somewhat counterintuitive:
+ // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row,
+ // so the planner inserts a converter to convert data into UnsafeRow if needed.
+ override def outputsUnsafeRows: Boolean = tungstenMode
+ override def canProcessSafeRows: Boolean = !tungstenMode
+ override def canProcessUnsafeRows: Boolean = tungstenMode
/**
* Determines whether records must be defensively copied before being sent to the shuffle.
@@ -124,23 +131,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
private val serializer: Serializer = {
val rowDataTypes = child.output.map(_.dataType).toArray
- // It is true when there is no field that needs to be write out.
- // For now, we will not use SparkSqlSerializer2 when noField is true.
- val noField = rowDataTypes == null || rowDataTypes.length == 0
-
- val useSqlSerializer2 =
- child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
- SparkSqlSerializer2.support(rowDataTypes) && // The schema of row is supported.
- !noField
-
- if (child.outputsUnsafeRows) {
- logInfo("Using UnsafeRowSerializer.")
+ if (tungstenMode) {
new UnsafeRowSerializer(child.output.size)
- } else if (useSqlSerializer2) {
- logInfo("Using SparkSqlSerializer2.")
- new SparkSqlSerializer2(rowDataTypes)
} else {
- logInfo("Using SparkSqlSerializer.")
new SparkSqlSerializer(sparkConf)
}
}
@@ -156,7 +149,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
val mutablePair = new MutablePair[InternalRow, Null]()
iter.map(row => mutablePair.update(row.copy(), null))
}
- implicit val ordering = new RowOrdering(sortingExpressions, child.output)
+ // We need to use an interpreted ordering here because generated orderings cannot be
+ // serialized and this ordering needs to be created on the driver in order to be passed into
+ // Spark core code.
+ implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output)
new RangePartitioner(numPartitions, rddForSampling, ascending = true)
case SinglePartition =>
new Partitioner {
@@ -194,108 +190,72 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
* of input data meets the
* [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
* each operator by inserting [[Exchange]] Operators where required. Also ensure that the
- * required input partition ordering requirements are met.
+ * input partition ordering requirements are met.
*/
private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
- def numPartitions: Int = sqlContext.conf.numShufflePartitions
+ private def numPartitions: Int = sqlContext.conf.numShufflePartitions
- def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
- case operator: SparkPlan =>
- // True iff every child's outputPartitioning satisfies the corresponding
- // required data distribution.
- def meetsRequirements: Boolean =
- operator.requiredChildDistribution.zip(operator.children).forall {
- case (required, child) =>
- val valid = child.outputPartitioning.satisfies(required)
- logDebug(
- s"${if (valid) "Valid" else "Invalid"} distribution," +
- s"required: $required current: ${child.outputPartitioning}")
- valid
- }
-
- // True iff any of the children are incorrectly sorted.
- def needsAnySort: Boolean =
- operator.requiredChildOrdering.zip(operator.children).exists {
- case (required, child) => required.nonEmpty && required != child.outputOrdering
- }
-
- // True iff outputPartitionings of children are compatible with each other.
- // It is possible that every child satisfies its required data distribution
- // but two children have incompatible outputPartitionings. For example,
- // A dataset is range partitioned by "a.asc" (RangePartitioning) and another
- // dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two
- // datasets are both clustered by "a", but these two outputPartitionings are not
- // compatible.
- // TODO: ASSUMES TRANSITIVITY?
- def compatible: Boolean =
- operator.children
- .map(_.outputPartitioning)
- .sliding(2)
- .forall {
- case Seq(a) => true
- case Seq(a, b) => a.compatibleWith(b)
- }
-
- // Adds Exchange or Sort operators as required
- def addOperatorsIfNecessary(
- partitioning: Partitioning,
- rowOrdering: Seq[SortOrder],
- child: SparkPlan): SparkPlan = {
+ /**
+ * Given a required distribution, returns a partitioning that satisfies that distribution.
+ */
+ private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = {
+ requiredDistribution match {
+ case AllTuples => SinglePartition
+ case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+ case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
+ case dist => sys.error(s"Do not know how to satisfy distribution $dist")
+ }
+ }
- def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
- if (child.outputPartitioning != partitioning) {
- Exchange(partitioning, child)
- } else {
- child
- }
- }
+ private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
+ val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
+ val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
+ var children: Seq[SparkPlan] = operator.children
- def addSortIfNecessary(child: SparkPlan): SparkPlan = {
+ // Ensure that the operator's children satisfy their output distribution requirements:
+ children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
+ if (child.outputPartitioning.satisfies(distribution)) {
+ child
+ } else {
+ Exchange(canonicalPartitioning(distribution), child)
+ }
+ }
- if (rowOrdering.nonEmpty) {
- // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort.
- val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
- if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
- sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
- } else {
- child
- }
- } else {
- child
- }
+ // If the operator has multiple children and specifies child output distributions (e.g. join),
+ // then the children's output partitionings must be compatible:
+ if (children.length > 1
+ && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
+ && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+ children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
+ val targetPartitioning = canonicalPartitioning(distribution)
+ if (child.outputPartitioning.guarantees(targetPartitioning)) {
+ child
+ } else {
+ Exchange(targetPartitioning, child)
}
-
- addSortIfNecessary(addShuffleIfNecessary(child))
}
+ }
- if (meetsRequirements && compatible && !needsAnySort) {
- operator
+ // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
+ children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+ if (requiredOrdering.nonEmpty) {
+ // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
+ val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min
+ if (minSize == 0 || requiredOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
+ sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, global = false, child)
+ } else {
+ child
+ }
} else {
- // At least one child does not satisfies its required data distribution or
- // at least one child's outputPartitioning is not compatible with another child's
- // outputPartitioning. In this case, we need to add Exchange operators.
- val requirements =
- (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
-
- val fixedChildren = requirements.zipped.map {
- case (AllTuples, rowOrdering, child) =>
- addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
- case (ClusteredDistribution(clustering), rowOrdering, child) =>
- addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
- case (OrderedDistribution(ordering), rowOrdering, child) =>
- addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
-
- case (UnspecifiedDistribution, Seq(), child) =>
- child
- case (UnspecifiedDistribution, rowOrdering, child) =>
- sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
+ child
+ }
+ }
- case (dist, ordering, _) =>
- sys.error(s"Don't know how to ensure $dist with ordering $ordering")
- }
+ operator.withNewChildren(children)
+ }
- operator.withNewChildren(fixedChildren)
- }
+ def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+ case operator: SparkPlan => ensureDistributionAndOrdering(operator)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index da27a753a710f..abb60cf12e3a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SQLContext}
@@ -95,8 +96,21 @@ private[sql] case class LogicalRDD(
/** Physical plan node for scanning data from an RDD. */
private[sql] case class PhysicalRDD(
output: Seq[Attribute],
- rdd: RDD[InternalRow]) extends LeafNode {
+ rdd: RDD[InternalRow],
+ extraInformation: String) extends LeafNode {
+
protected override def doExecute(): RDD[InternalRow] = rdd
+
+ override def simpleString: String = "Scan " + extraInformation + output.mkString("[", ",", "]")
+}
+
+private[sql] object PhysicalRDD {
+ def createFromDataSource(
+ output: Seq[Attribute],
+ rdd: RDD[InternalRow],
+ relation: BaseRelation): PhysicalRDD = {
+ PhysicalRDD(output, rdd, relation.toString)
+ }
}
/** Logical plan node for scanning data from a local collection. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala
new file mode 100644
index 0000000000000..7a2a9eed5807d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+/**
+ * An interface for relations that are backed by files. When a class implements this interface,
+ * the list of paths that it returns will be returned to a user who calls `inputPaths` on any
+ * DataFrame that queries this relation.
+ */
+private[sql] trait FileRelation {
+ /** Returns the list of files that will be read when scanning this relation. */
+ def inputFiles: Array[String]
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
deleted file mode 100644
index b85aada9d9d4c..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ /dev/null
@@ -1,336 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.trees._
-import org.apache.spark.sql.types._
-
-case class AggregateEvaluation(
- schema: Seq[Attribute],
- initialValues: Seq[Expression],
- update: Seq[Expression],
- result: Expression)
-
-/**
- * :: DeveloperApi ::
- * Alternate version of aggregation that leverages projection and thus code generation.
- * Aggregations are converted into a set of projections from a aggregation buffer tuple back onto
- * itself. Currently only used for simple aggregations like SUM, COUNT, or AVERAGE are supported.
- *
- * @param partial if true then aggregation is done partially on local data without shuffling to
- * ensure all values where `groupingExpressions` are equal are present.
- * @param groupingExpressions expressions that are evaluated to determine grouping.
- * @param aggregateExpressions expressions that are computed for each group.
- * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
- * @param child the input data source.
- */
-@DeveloperApi
-case class GeneratedAggregate(
- partial: Boolean,
- groupingExpressions: Seq[Expression],
- aggregateExpressions: Seq[NamedExpression],
- unsafeEnabled: Boolean,
- child: SparkPlan)
- extends UnaryNode {
-
- override def requiredChildDistribution: Seq[Distribution] =
- if (partial) {
- UnspecifiedDistribution :: Nil
- } else {
- if (groupingExpressions == Nil) {
- AllTuples :: Nil
- } else {
- ClusteredDistribution(groupingExpressions) :: Nil
- }
- }
-
- override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
-
- protected override def doExecute(): RDD[InternalRow] = {
- val aggregatesToCompute = aggregateExpressions.flatMap { a =>
- a.collect { case agg: AggregateExpression1 => agg}
- }
-
- // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
- // (in test "aggregation with codegen").
- val computeFunctions = aggregatesToCompute.map {
- case c @ Count(expr) =>
- // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
- // UnscaledValue will be null if and only if x is null; helps with Average on decimals
- val toCount = expr match {
- case UnscaledValue(e) => e
- case _ => expr
- }
- val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
- val initialValue = Literal(0L)
- val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
- val result = currentCount
-
- AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
- case s @ Sum(expr) =>
- val calcType =
- expr.dataType match {
- case DecimalType.Fixed(p, s) =>
- DecimalType.bounded(p + 10, s)
- case _ =>
- expr.dataType
- }
-
- val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
- val initialValue = Literal.create(null, calcType)
-
- // Coalesce avoids double calculation...
- // but really, common sub expression elimination would be better....
- val zero = Cast(Literal(0), calcType)
- val updateFunction = Coalesce(
- Add(
- Coalesce(currentSum :: zero :: Nil),
- Cast(expr, calcType)
- ) :: currentSum :: Nil)
- val result =
- expr.dataType match {
- case DecimalType.Fixed(_, _) =>
- Cast(currentSum, s.dataType)
- case _ => currentSum
- }
-
- AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
- case m @ Max(expr) =>
- val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
- val initialValue = Literal.create(null, expr.dataType)
- val updateMax = MaxOf(currentMax, expr)
-
- AggregateEvaluation(
- currentMax :: Nil,
- initialValue :: Nil,
- updateMax :: Nil,
- currentMax)
-
- case m @ Min(expr) =>
- val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)()
- val initialValue = Literal.create(null, expr.dataType)
- val updateMin = MinOf(currentMin, expr)
-
- AggregateEvaluation(
- currentMin :: Nil,
- initialValue :: Nil,
- updateMin :: Nil,
- currentMin)
-
- case CollectHashSet(Seq(expr)) =>
- val set =
- AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
- val initialValue = NewSet(expr.dataType)
- val addToSet = AddItemToSet(expr, set)
-
- AggregateEvaluation(
- set :: Nil,
- initialValue :: Nil,
- addToSet :: Nil,
- set)
-
- case CombineSetsAndCount(inputSet) =>
- val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
- val set =
- AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
- val initialValue = NewSet(elementType)
- val collectSets = CombineSets(set, inputSet)
-
- AggregateEvaluation(
- set :: Nil,
- initialValue :: Nil,
- collectSets :: Nil,
- CountSet(set))
-
- case o => sys.error(s"$o can't be codegened.")
- }
-
- val computationSchema = computeFunctions.flatMap(_.schema)
-
- val resultMap: Map[TreeNodeRef, Expression] =
- aggregatesToCompute.zip(computeFunctions).map {
- case (agg, func) => new TreeNodeRef(agg) -> func.result
- }.toMap
-
- val namedGroups = groupingExpressions.zipWithIndex.map {
- case (ne: NamedExpression, _) => (ne, ne.toAttribute)
- case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
- }
-
- // The set of expressions that produce the final output given the aggregation buffer and the
- // grouping expressions.
- val resultExpressions = aggregateExpressions.map(_.transform {
- case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
- case e: Expression =>
- namedGroups.collectFirst {
- case (expr, attr) if expr semanticEquals e => attr
- }.getOrElse(e)
- })
-
- val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
-
- val groupKeySchema: StructType = {
- val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
- // This is a dummy field name
- StructField(idx.toString, expr.dataType, expr.nullable)
- }
- StructType(fields)
- }
-
- val schemaSupportsUnsafe: Boolean = {
- UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
- UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
- }
-
- child.execute().mapPartitions { iter =>
- // Builds a new custom class for holding the results of aggregation for a group.
- val initialValues = computeFunctions.flatMap(_.initialValues)
- val newAggregationBuffer = newProjection(initialValues, child.output)
- log.info(s"Initial values: ${initialValues.mkString(",")}")
-
- // A projection that computes the group given an input tuple.
- val groupProjection = newProjection(groupingExpressions, child.output)
- log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
-
- // A projection that is used to update the aggregate values for a group given a new tuple.
- // This projection should be targeted at the current values for the group and then applied
- // to a joined row of the current values with the new input row.
- val updateExpressions = computeFunctions.flatMap(_.update)
- val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
- val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
- log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
-
- // A projection that produces the final result, given a computation.
- val resultProjectionBuilder =
- newMutableProjection(
- resultExpressions,
- namedGroups.map(_._2) ++ computationSchema)
- log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
-
- val joinedRow = new JoinedRow
-
- if (!iter.hasNext) {
- // This is an empty input, so return early so that we do not allocate data structures
- // that won't be cleaned up (see SPARK-8357).
- if (groupingExpressions.isEmpty) {
- // This is a global aggregate, so return an empty aggregation buffer.
- val resultProjection = resultProjectionBuilder()
- Iterator(resultProjection(newAggregationBuffer(EmptyRow)))
- } else {
- // This is a grouped aggregate, so return an empty iterator.
- Iterator[InternalRow]()
- }
- } else if (groupingExpressions.isEmpty) {
- // TODO: Codegening anything other than the updateProjection is probably over kill.
- val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
- var currentRow: InternalRow = null
- updateProjection.target(buffer)
-
- while (iter.hasNext) {
- currentRow = iter.next()
- updateProjection(joinedRow(buffer, currentRow))
- }
-
- val resultProjection = resultProjectionBuilder()
- Iterator(resultProjection(buffer))
-
- } else if (unsafeEnabled && schemaSupportsUnsafe) {
- assert(iter.hasNext, "There should be at least one row for this path")
- log.info("Using Unsafe-based aggregator")
- val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
- val aggregationMap = new UnsafeFixedWidthAggregationMap(
- newAggregationBuffer(EmptyRow),
- aggregationBufferSchema,
- groupKeySchema,
- TaskContext.get.taskMemoryManager(),
- 1024 * 16, // initial capacity
- pageSizeBytes,
- false // disable tracking of performance metrics
- )
-
- while (iter.hasNext) {
- val currentRow: InternalRow = iter.next()
- val groupKey: InternalRow = groupProjection(currentRow)
- val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
- updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
- }
-
- new Iterator[InternalRow] {
- private[this] val mapIterator = aggregationMap.iterator()
- private[this] val resultProjection = resultProjectionBuilder()
-
- def hasNext: Boolean = mapIterator.hasNext
-
- def next(): InternalRow = {
- val entry = mapIterator.next()
- val result = resultProjection(joinedRow(entry.key, entry.value))
- if (hasNext) {
- result
- } else {
- // This is the last element in the iterator, so let's free the buffer. Before we do,
- // though, we need to make a defensive copy of the result so that we don't return an
- // object that might contain dangling pointers to the freed memory
- val resultCopy = result.copy()
- aggregationMap.free()
- resultCopy
- }
- }
- }
- } else {
- if (unsafeEnabled) {
- log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
- }
- val buffers = new java.util.HashMap[InternalRow, MutableRow]()
-
- var currentRow: InternalRow = null
- while (iter.hasNext) {
- currentRow = iter.next()
- val currentGroup = groupProjection(currentRow)
- var currentBuffer = buffers.get(currentGroup)
- if (currentBuffer == null) {
- currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
- buffers.put(currentGroup, currentBuffer)
- }
- // Target the projection at the current aggregation buffer and then project the updated
- // values.
- updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
- }
-
- new Iterator[InternalRow] {
- private[this] val resultIterator = buffers.entrySet.iterator()
- private[this] val resultProjection = resultProjectionBuilder()
-
- def hasNext: Boolean = resultIterator.hasNext
-
- def next(): InternalRow = {
- val currentGroup = resultIterator.next()
- resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
- }
- }
- }
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
new file mode 100644
index 0000000000000..7462dbc4eba3a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.NoSuchElementException
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * An internal iterator interface which presents a more restrictive API than
+ * [[scala.collection.Iterator]].
+ *
+ * One major departure from the Scala iterator API is the fusing of the `hasNext()` and `next()`
+ * calls: Scala's iterator allows users to call `hasNext()` without immediately advancing the
+ * iterator to consume the next row, whereas RowIterator combines these calls into a single
+ * [[advanceNext()]] method.
+ */
+private[sql] abstract class RowIterator {
+ /**
+ * Advance this iterator by a single row. Returns `false` if this iterator has no more rows
+ * and `true` otherwise. If this returns `true`, then the new row can be retrieved by calling
+ * [[getRow]].
+ */
+ def advanceNext(): Boolean
+
+ /**
+ * Retrieve the row from this iterator. This method is idempotent. It is illegal to call this
+ * method after [[advanceNext()]] has returned `false`.
+ */
+ def getRow: InternalRow
+
+ /**
+ * Convert this RowIterator into a [[scala.collection.Iterator]].
+ */
+ def toScala: Iterator[InternalRow] = new RowIteratorToScala(this)
+}
+
+object RowIterator {
+ def fromScala(scalaIter: Iterator[InternalRow]): RowIterator = {
+ scalaIter match {
+ case wrappedRowIter: RowIteratorToScala => wrappedRowIter.rowIter
+ case _ => new RowIteratorFromScala(scalaIter)
+ }
+ }
+}
+
+private final class RowIteratorToScala(val rowIter: RowIterator) extends Iterator[InternalRow] {
+ private [this] var hasNextWasCalled: Boolean = false
+ private [this] var _hasNext: Boolean = false
+ override def hasNext: Boolean = {
+ // Idempotency:
+ if (!hasNextWasCalled) {
+ _hasNext = rowIter.advanceNext()
+ hasNextWasCalled = true
+ }
+ _hasNext
+ }
+ override def next(): InternalRow = {
+ if (!hasNext) throw new NoSuchElementException
+ hasNextWasCalled = false
+ rowIter.getRow
+ }
+}
+
+private final class RowIteratorFromScala(scalaIter: Iterator[InternalRow]) extends RowIterator {
+ private[this] var _next: InternalRow = null
+ override def advanceNext(): Boolean = {
+ if (scalaIter.hasNext) {
+ _next = scalaIter.next()
+ true
+ } else {
+ _next = null
+ false
+ }
+ }
+ override def getRow: InternalRow = _next
+ override def toScala: Iterator[InternalRow] = scalaIter
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
new file mode 100644
index 0000000000000..cee58218a885b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.ui.SparkPlanGraph
+import org.apache.spark.util.Utils
+
+private[sql] object SQLExecution {
+
+ val EXECUTION_ID_KEY = "spark.sql.execution.id"
+
+ private val _nextExecutionId = new AtomicLong(0)
+
+ private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
+
+ /**
+ * Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that
+ * we can connect them with an execution.
+ */
+ def withNewExecutionId[T](
+ sqlContext: SQLContext, queryExecution: SQLContext#QueryExecution)(body: => T): T = {
+ val sc = sqlContext.sparkContext
+ val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
+ if (oldExecutionId == null) {
+ val executionId = SQLExecution.nextExecutionId
+ sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
+ val r = try {
+ val callSite = Utils.getCallSite()
+ sqlContext.listener.onExecutionStart(
+ executionId,
+ callSite.shortForm,
+ callSite.longForm,
+ queryExecution.toString,
+ SparkPlanGraph(queryExecution.executedPlan),
+ System.currentTimeMillis())
+ try {
+ body
+ } finally {
+ // Ideally, we need to make sure onExecutionEnd happens after onJobStart and onJobEnd.
+ // However, onJobStart and onJobEnd run in the listener thread. Because we cannot add new
+ // SQL event types to SparkListener since it's a public API, we cannot guarantee that.
+ //
+ // SQLListener should handle the case that onExecutionEnd happens before onJobEnd.
+ //
+ // The worst case is onExecutionEnd may happen before onJobStart when the listener thread
+ // is very busy. If so, we cannot track the jobs for the execution. It seems acceptable.
+ sqlContext.listener.onExecutionEnd(executionId, System.currentTimeMillis())
+ }
+ } finally {
+ sc.setLocalProperty(EXECUTION_ID_KEY, null)
+ }
+ r
+ } else {
+ // Don't support nested `withNewExecutionId`. This is an example of the nested
+ // `withNewExecutionId`:
+ //
+ // class DataFrame {
+ // def foo: T = withNewExecutionId { something.createNewDataFrame().collect() }
+ // }
+ //
+ // Note: `collect` will call withNewExecutionId
+ // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan"
+ // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution
+ // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run,
+ // all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
+ //
+ // A real case is the `DataFrame.count` method.
+ throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set")
+ }
+ }
+
+ /**
+ * Wrap an action with a known executionId. When running a different action in a different
+ * thread from the original one, this method can be used to connect the Spark jobs in this action
+ * with the known executionId, e.g., `BroadcastHashJoin.broadcastFuture`.
+ */
+ def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
+ val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ try {
+ sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
+ body
+ } finally {
+ sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 2dee3542d6101..e17b50edc62dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -19,9 +19,8 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
@@ -37,61 +36,53 @@ object SortPrefixUtils {
def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
sortOrder.dataType match {
- case StringType => PrefixComparators.STRING
- case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL
- case FloatType => PrefixComparators.FLOAT
- case DoubleType => PrefixComparators.DOUBLE
+ case StringType =>
+ if (sortOrder.isAscending) PrefixComparators.STRING else PrefixComparators.STRING_DESC
+ case BinaryType =>
+ if (sortOrder.isAscending) PrefixComparators.BINARY else PrefixComparators.BINARY_DESC
+ case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType =>
+ if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC
+ case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+ if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC
+ case FloatType | DoubleType =>
+ if (sortOrder.isAscending) PrefixComparators.DOUBLE else PrefixComparators.DOUBLE_DESC
+ case dt: DecimalType =>
+ if (sortOrder.isAscending) PrefixComparators.DOUBLE else PrefixComparators.DOUBLE_DESC
case _ => NoOpPrefixComparator
}
}
- def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = {
- sortOrder.dataType match {
- case StringType => (row: InternalRow) => {
- PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String])
+ /**
+ * Creates the prefix comparator for the first field in the given schema, in ascending order.
+ */
+ def getPrefixComparator(schema: StructType): PrefixComparator = {
+ if (schema.nonEmpty) {
+ val field = schema.head
+ getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+ } else {
+ new PrefixComparator {
+ override def compare(prefix1: Long, prefix2: Long): Int = 0
}
- case BooleanType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1
- else 0
- }
- case ByteType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Byte]
- }
- case ShortType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Short]
- }
- case IntegerType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Int]
- }
- case LongType =>
- (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
- else sortOrder.child.eval(row).asInstanceOf[Long]
+ }
+ }
+
+ /**
+ * Creates the prefix computer for the first field in the given schema, in ascending order.
+ */
+ def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
+ if (schema.nonEmpty) {
+ val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
+ val prefixProjection = UnsafeProjection.create(
+ SortPrefix(SortOrder(boundReference, Ascending)))
+ new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = {
+ prefixProjection.apply(row).getLong(0)
}
- case FloatType => (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX
- else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float])
}
- case DoubleType => (row: InternalRow) => {
- val exprVal = sortOrder.child.eval(row)
- if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX
- else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double])
+ } else {
+ new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = 0
}
- case _ => (row: InternalRow) => 0L
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 50c27def8ea54..72f5450510a10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.concurrent.atomic.AtomicBoolean
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.Logging
@@ -30,6 +32,8 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric, SQLMetrics}
+import org.apache.spark.sql.types.DataType
object SparkPlan {
protected[sql] val currentContext = new ThreadLocal[SQLContext]()
@@ -52,19 +56,41 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
protected def sparkContext = sqlContext.sparkContext
// sqlContext will be null when we are being deserialized on the slaves. In this instance
- // the value of codegenEnabled will be set by the desserializer after the constructor has run.
+ // the value of codegenEnabled/unsafeEnabled will be set by the desserializer after the
+ // constructor has run.
val codegenEnabled: Boolean = if (sqlContext != null) {
sqlContext.conf.codegenEnabled
} else {
false
}
+ val unsafeEnabled: Boolean = if (sqlContext != null) {
+ sqlContext.conf.unsafeEnabled
+ } else {
+ false
+ }
+
+ /**
+ * Whether the "prepare" method is called.
+ */
+ private val prepareCalled = new AtomicBoolean(false)
/** Overridden make copy also propogates sqlContext to copied plan. */
- override def makeCopy(newArgs: Array[AnyRef]): this.type = {
+ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
SparkPlan.currentContext.set(sqlContext)
super.makeCopy(newArgs)
}
+ /**
+ * Return all metrics containing metrics of this SparkPlan.
+ */
+ private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty
+
+ /**
+ * Return a LongSQLMetric according to the name.
+ */
+ private[sql] def longMetric(name: String): LongSQLMetric =
+ metrics(name).asInstanceOf[LongSQLMetric]
+
// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
@@ -110,10 +136,31 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
"Operator will receive unsafe rows as input but cannot process unsafe rows")
}
RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
+ prepare()
doExecute()
}
}
+ /**
+ * Prepare a SparkPlan for execution. It's idempotent.
+ */
+ final def prepare(): Unit = {
+ if (prepareCalled.compareAndSet(false, true)) {
+ doPrepare()
+ children.foreach(_.prepare())
+ }
+ }
+
+ /**
+ * Overridden by concrete implementations of SparkPlan. It is guaranteed to run before any
+ * `execute` of SparkPlan. This is helpful if we want to set up some state before executing the
+ * query, e.g., `BroadcastHashJoin` uses it to broadcast asynchronously.
+ *
+ * Note: the prepare method has already walked down the tree, so the implementation doesn't need
+ * to call children's prepare methods.
+ */
+ protected def doPrepare(): Unit = {}
+
/**
* Overridden by concrete implementations of SparkPlan.
* Produces the result of the query as an RDD[InternalRow]
@@ -251,12 +298,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
throw e
} else {
log.error("Failed to generate ordering, fallback to interpreted", e)
- new RowOrdering(order, inputSchema)
+ new InterpretedOrdering(order, inputSchema)
}
}
} else {
- new RowOrdering(order, inputSchema)
+ new InterpretedOrdering(order, inputSchema)
+ }
+ }
+ /**
+ * Creates a row ordering for the given schema, in natural ascending order.
+ */
+ protected def newNaturalAscendingOrdering(dataTypes: Seq[DataType]): Ordering[InternalRow] = {
+ val order: Seq[SortOrder] = dataTypes.zipWithIndex.map {
+ case (dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
}
+ newOrdering(order, Seq.empty)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
deleted file mode 100644
index c808442a4849b..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ /dev/null
@@ -1,425 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.io._
-import java.math.{BigDecimal, BigInteger}
-import java.nio.ByteBuffer
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.Logging
-import org.apache.spark.serializer._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-/**
- * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in
- * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the
- * [[Product2]] are constructed based on their schemata.
- * The benefit of this serialization stream is that compared with general-purpose serializers like
- * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower
- * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are:
- * 1. It does not support complex types, i.e. Map, Array, and Struct.
- * 2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when
- * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because
- * the objects passed in the serializer are not in the type of [[Product2]]. Also also see
- * the comment of the `serializer` method in [[Exchange]] for more information on it.
- */
-private[sql] class Serializer2SerializationStream(
- rowSchema: Array[DataType],
- out: OutputStream)
- extends SerializationStream with Logging {
-
- private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
- private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut)
-
- override def writeObject[T: ClassTag](t: T): SerializationStream = {
- val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]]
- writeKey(kv._1)
- writeValue(kv._2)
-
- this
- }
-
- override def writeKey[T: ClassTag](t: T): SerializationStream = {
- // No-op.
- this
- }
-
- override def writeValue[T: ClassTag](t: T): SerializationStream = {
- writeRowFunc(t.asInstanceOf[InternalRow])
- this
- }
-
- def flush(): Unit = {
- rowOut.flush()
- }
-
- def close(): Unit = {
- rowOut.close()
- }
-}
-
-/**
- * The corresponding deserialization stream for [[Serializer2SerializationStream]].
- */
-private[sql] class Serializer2DeserializationStream(
- rowSchema: Array[DataType],
- in: InputStream)
- extends DeserializationStream with Logging {
-
- private val rowIn = new DataInputStream(new BufferedInputStream(in))
-
- private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
- if (schema == null) {
- () => null
- } else {
- // It is safe to reuse the mutable row.
- val mutableRow = new SpecificMutableRow(schema)
- () => mutableRow
- }
- }
-
- // Functions used to return rows for key and value.
- private val getRow = rowGenerator(rowSchema)
- // Functions used to read a serialized row from the InputStream and deserialize it.
- private val readRowFunc = SparkSqlSerializer2.createDeserializationFunction(rowSchema, rowIn)
-
- override def readObject[T: ClassTag](): T = {
- readValue()
- }
-
- override def readKey[T: ClassTag](): T = {
- null.asInstanceOf[T] // intentionally left blank.
- }
-
- override def readValue[T: ClassTag](): T = {
- readRowFunc(getRow()).asInstanceOf[T]
- }
-
- override def close(): Unit = {
- rowIn.close()
- }
-}
-
-private[sql] class SparkSqlSerializer2Instance(
- rowSchema: Array[DataType])
- extends SerializerInstance {
-
- def serialize[T: ClassTag](t: T): ByteBuffer =
- throw new UnsupportedOperationException("Not supported.")
-
- def deserialize[T: ClassTag](bytes: ByteBuffer): T =
- throw new UnsupportedOperationException("Not supported.")
-
- def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
- throw new UnsupportedOperationException("Not supported.")
-
- def serializeStream(s: OutputStream): SerializationStream = {
- new Serializer2SerializationStream(rowSchema, s)
- }
-
- def deserializeStream(s: InputStream): DeserializationStream = {
- new Serializer2DeserializationStream(rowSchema, s)
- }
-}
-
-/**
- * SparkSqlSerializer2 is a special serializer that creates serialization function and
- * deserialization function based on the schema of data. It assumes that values passed in
- * are Rows.
- */
-private[sql] class SparkSqlSerializer2(rowSchema: Array[DataType])
- extends Serializer
- with Logging
- with Serializable{
-
- def newInstance(): SerializerInstance = new SparkSqlSerializer2Instance(rowSchema)
-
- override def supportsRelocationOfSerializedObjects: Boolean = {
- // SparkSqlSerializer2 is stateless and writes no stream headers
- true
- }
-}
-
-private[sql] object SparkSqlSerializer2 {
-
- final val NULL = 0
- final val NOT_NULL = 1
-
- /**
- * Check if rows with the given schema can be serialized with ShuffleSerializer.
- * Right now, we do not support a schema having complex types or UDTs, or all data types
- * of fields are NullTypes.
- */
- def support(schema: Array[DataType]): Boolean = {
- if (schema == null) return true
-
- var allNullTypes = true
- var i = 0
- while (i < schema.length) {
- schema(i) match {
- case NullType => // Do nothing
- case udt: UserDefinedType[_] =>
- allNullTypes = false
- return false
- case array: ArrayType =>
- allNullTypes = false
- return false
- case map: MapType =>
- allNullTypes = false
- return false
- case struct: StructType =>
- allNullTypes = false
- return false
- case _ =>
- allNullTypes = false
- }
- i += 1
- }
-
- // If types of fields are all NullTypes, we return false.
- // Otherwise, we return true.
- return !allNullTypes
- }
-
- /**
- * The util function to create the serialization function based on the given schema.
- */
- def createSerializationFunction(schema: Array[DataType], out: DataOutputStream)
- : InternalRow => Unit = {
- (row: InternalRow) =>
- // If the schema is null, the returned function does nothing when it get called.
- if (schema != null) {
- var i = 0
- while (i < schema.length) {
- schema(i) match {
- // When we write values to the underlying stream, we also first write the null byte
- // first. Then, if the value is not null, we write the contents out.
-
- case NullType => // Write nothing.
-
- case BooleanType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeBoolean(row.getBoolean(i))
- }
-
- case ByteType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeByte(row.getByte(i))
- }
-
- case ShortType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeShort(row.getShort(i))
- }
-
- case IntegerType | DateType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeInt(row.getInt(i))
- }
-
- case LongType | TimestampType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeLong(row.getLong(i))
- }
-
- case FloatType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeFloat(row.getFloat(i))
- }
-
- case DoubleType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- out.writeDouble(row.getDouble(i))
- }
-
- case StringType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- val bytes = row.getUTF8String(i).getBytes
- out.writeInt(bytes.length)
- out.write(bytes)
- }
-
- case BinaryType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- val bytes = row.getBinary(i)
- out.writeInt(bytes.length)
- out.write(bytes)
- }
-
- case decimal: DecimalType =>
- if (row.isNullAt(i)) {
- out.writeByte(NULL)
- } else {
- out.writeByte(NOT_NULL)
- val value = row.getDecimal(i)
- val javaBigDecimal = value.toJavaBigDecimal
- // First, write out the unscaled value.
- val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
- out.writeInt(bytes.length)
- out.write(bytes)
- // Then, write out the scale.
- out.writeInt(javaBigDecimal.scale())
- }
- }
- i += 1
- }
- }
- }
-
- /**
- * The util function to create the deserialization function based on the given schema.
- */
- def createDeserializationFunction(
- schema: Array[DataType],
- in: DataInputStream): (MutableRow) => InternalRow = {
- if (schema == null) {
- (mutableRow: MutableRow) => null
- } else {
- (mutableRow: MutableRow) => {
- var i = 0
- while (i < schema.length) {
- schema(i) match {
- // When we read values from the underlying stream, we also first read the null byte
- // first. Then, if the value is not null, we update the field of the mutable row.
-
- case NullType => mutableRow.setNullAt(i) // Read nothing.
-
- case BooleanType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setBoolean(i, in.readBoolean())
- }
-
- case ByteType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setByte(i, in.readByte())
- }
-
- case ShortType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setShort(i, in.readShort())
- }
-
- case IntegerType | DateType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setInt(i, in.readInt())
- }
-
- case LongType | TimestampType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setLong(i, in.readLong())
- }
-
- case FloatType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setFloat(i, in.readFloat())
- }
-
- case DoubleType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- mutableRow.setDouble(i, in.readDouble())
- }
-
- case StringType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- val length = in.readInt()
- val bytes = new Array[Byte](length)
- in.readFully(bytes)
- mutableRow.update(i, UTF8String.fromBytes(bytes))
- }
-
- case BinaryType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- val length = in.readInt()
- val bytes = new Array[Byte](length)
- in.readFully(bytes)
- mutableRow.update(i, bytes)
- }
-
- case decimal: DecimalType =>
- if (in.readByte() == NULL) {
- mutableRow.setNullAt(i)
- } else {
- // First, read in the unscaled value.
- val length = in.readInt()
- val bytes = new Array[Byte](length)
- in.readFully(bytes)
- val unscaledVal = new BigInteger(bytes)
- // Then, read the scale.
- val scale = in.readInt()
- // Finally, create the Decimal object and set it in the row.
- mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale)))
- }
- }
- i += 1
- }
-
- mutableRow
- }
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 52a9b02d373c7..1fc870d44b578 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -63,19 +63,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
- * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be
- * evaluated by matching hash keys.
+ * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates
+ * can be evaluated by matching join keys.
*
- * This strategy applies a simple optimization based on the estimates of the physical sizes of
- * the two join sides. When planning a [[joins.BroadcastHashJoin]], if one side has an
- * estimated physical size smaller than the user-settable threshold
- * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the
- * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be
- * ''broadcasted'' to all of the executors involved in the join, as a
- * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they
- * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]].
+ * Join implementations are chosen with the following precedence:
+ *
+ * - Broadcast: if one side of the join has an estimated physical size that is smaller than the
+ * user-configurable [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold
+ * or if that side has an explicit broadcast hint (e.g. the user applied the
+ * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
+ * of the join will be broadcasted and the other side will be streamed, with no shuffling
+ * performed. If both sides of the join are eligible to be broadcasted, the smaller side is chosen.
+ * - Sort merge: if the matching join keys are sortable and
+ * [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), then sort merge join
+ * will be used.
+ * - Hash: will be chosen if neither of the above optimizations apply to this join.
*/
- object HashJoin extends Strategy with PredicateHelper {
+ object EquiJoinSelection extends Strategy with PredicateHelper {
private[this] def makeBroadcastHashJoin(
leftKeys: Seq[Expression],
@@ -89,29 +93,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
}
- private[this] def isValidSort(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression]): Boolean = {
- leftKeys.zip(rightKeys).forall { keys =>
- (keys._1.dataType, keys._2.dataType) match {
- case (l: AtomicType, r: AtomicType) => true
- case (NullType, NullType) => true
- case _ => false
- }
- }
- }
-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+
+ // --- Inner joins --------------------------------------------------------------------------
+
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
- // If the sort merge join option is set, we want to use sort merge join prior to hashjoin
- // for now let's support inner join first, then add outer join
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
- if sqlContext.conf.sortMergeJoinEnabled && isValidSort(leftKeys, rightKeys) =>
+ if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
val mergeJoin =
joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
@@ -127,6 +120,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
+ // --- Outer joins --------------------------------------------------------------------------
+
case ExtractEquiJoinKeys(
LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
joins.BroadcastHashOuterJoin(
@@ -137,10 +132,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
joins.BroadcastHashOuterJoin(
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+ case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
+ if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
+ joins.SortMergeOuterJoin(
+ leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
+
+ case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
+ if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
+ joins.SortMergeOuterJoin(
+ leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
joins.ShuffledHashOuterJoin(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+ // --- Cases where this strategy does not apply ---------------------------------------------
+
case _ => Nil
}
}
@@ -148,32 +155,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object HashAggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Aggregations that can be performed in two phases, before and after the shuffle.
-
- // Cases where all aggregates can be codegened.
- case PartialAggregation(
- namedGroupingAttributes,
- rewrittenAggregateExpressions,
- groupingExpressions,
- partialComputation,
- child)
- if canBeCodeGened(
- allAggregates(partialComputation) ++
- allAggregates(rewrittenAggregateExpressions)) &&
- codegenEnabled &&
- !canBeConvertedToNewAggregation(plan) =>
- execution.GeneratedAggregate(
- partial = false,
- namedGroupingAttributes,
- rewrittenAggregateExpressions,
- unsafeEnabled,
- execution.GeneratedAggregate(
- partial = true,
- groupingExpressions,
- partialComputation,
- unsafeEnabled,
- planLater(child))) :: Nil
-
- // Cases where some aggregate can not be codegened
case PartialAggregation(
namedGroupingAttributes,
rewrittenAggregateExpressions,
@@ -204,14 +185,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => false
}
- def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
- case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
- // The generated set implementation is pretty limited ATM.
- case CollectHashSet(exprs) if exprs.size == 1 &&
- Seq(IntegerType, LongType).contains(exprs.head.dataType) => true
- case _ => false
- }
-
def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
}
@@ -237,8 +210,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// aggregate function to the corresponding attribute of the function.
val aggregateFunctionMap = aggregateExpressions.map { agg =>
val aggregateFunction = agg.aggregateFunction
+ val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
 (aggregateFunction, agg.isDistinct) ->
- Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+ (aggregateFunction -> attribute)
}.toMap
val (functionsWithDistinct, functionsWithoutDistinct) =
@@ -341,8 +315,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = {
if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled &&
- UnsafeExternalSort.supportsSchema(child.schema)) {
- execution.UnsafeExternalSort(sortExprs, global, child)
+ TungstenSort.supportsSchema(child.schema)) {
+ execution.TungstenSort(sortExprs, global, child)
} else if (sqlContext.conf.externalSortEnabled) {
execution.ExternalSort(sortExprs, global, child)
} else {
@@ -389,8 +363,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
}
}
- case logical.Window(projectList, windowExpressions, spec, child) =>
- execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil
+ case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
+ execution.Window(
+ projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
@@ -407,12 +382,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Generate(
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
case logical.OneRowRelation =>
- execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
+ execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil
case logical.RepartitionByExpression(expressions, child) =>
execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
case e @ EvaluatePython(udf, child, _) =>
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
- case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
+ case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil
case BroadcastHint(child) => apply(child)
case _ => Nil
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 16498da080c88..5c18558f9bde7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream}
+import java.io._
import java.nio.ByteBuffer
import scala.reflect.ClassTag
@@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams
import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.Platform
/**
* Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored as
@@ -58,12 +58,14 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
*/
override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
- private[this] val dOut: DataOutputStream = new DataOutputStream(out)
+ private[this] val dOut: DataOutputStream =
+ new DataOutputStream(new BufferedOutputStream(out))
override def writeValue[T: ClassTag](value: T): SerializationStream = {
val row = value.asInstanceOf[UnsafeRow]
+
dOut.writeInt(row.getSizeInBytes)
- row.writeToStream(out, writeBuffer)
+ row.writeToStream(dOut, writeBuffer)
this
}
@@ -97,7 +99,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
override def deserializeStream(in: InputStream): DeserializationStream = {
new DeserializationStream {
- private[this] val dIn: DataInputStream = new DataInputStream(in)
+ private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in))
// 1024 is a default buffer size; this buffer will grow to accommodate larger rows
private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
private[this] var row: UnsafeRow = new UnsafeRow()
@@ -106,6 +108,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
new Iterator[(Int, UnsafeRow)] {
private[this] var rowSize: Int = dIn.readInt()
+ if (rowSize == EOF) dIn.close()
override def hasNext: Boolean = rowSize != EOF
@@ -113,10 +116,11 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
if (rowBuffer.length < rowSize) {
rowBuffer = new Array[Byte](rowSize)
}
- ByteStreams.readFully(in, rowBuffer, 0, rowSize)
- row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize)
+ ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
+ row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
rowSize = dIn.readInt() // read the next row's size
if (rowSize == EOF) { // We are returning the last row in this stream
+ dIn.close()
val _rowTuple = rowTuple
// Null these out so that the byte array can be garbage collected once the entire
// iterator has been consumed
@@ -147,8 +151,8 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
if (rowBuffer.length < rowSize) {
rowBuffer = new Array[Byte](rowSize)
}
- ByteStreams.readFully(in, rowBuffer, 0, rowSize)
- row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize)
+ ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
+ row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
row.asInstanceOf[T]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 91c8a02e2b5bc..0269d6d4b7a1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -80,26 +80,29 @@ import scala.collection.mutable
case class Window(
projectList: Seq[Attribute],
windowExpression: Seq[NamedExpression],
- windowSpec: WindowSpecDefinition,
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
child: SparkPlan)
extends UnaryNode {
override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute)
override def requiredChildDistribution: Seq[Distribution] = {
- if (windowSpec.partitionSpec.isEmpty) {
+ if (partitionSpec.isEmpty) {
// Only show warning when the number of bytes is larger than 100 MB?
logWarning("No Partition Defined for Window operation! Moving all data to a single "
+ "partition, this can cause serious performance degradation.")
AllTuples :: Nil
- } else ClusteredDistribution(windowSpec.partitionSpec) :: Nil
+ } else ClusteredDistribution(partitionSpec) :: Nil
}
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
- Seq(windowSpec.partitionSpec.map(SortOrder(_, Ascending)) ++ windowSpec.orderSpec)
+ Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ override def canProcessUnsafeRows: Boolean = true
+
/**
* Create a bound ordering object for a given frame type and offset. A bound ordering object is
* used to determine which input row lies within the frame boundaries of an output row.
@@ -115,12 +118,12 @@ case class Window(
case RangeFrame =>
val (exprs, current, bound) = if (offset == 0) {
// Use the entire order expression when the offset is 0.
- val exprs = windowSpec.orderSpec.map(_.child)
+ val exprs = orderSpec.map(_.child)
val projection = newMutableProjection(exprs, child.output)
- (windowSpec.orderSpec, projection(), projection())
- } else if (windowSpec.orderSpec.size == 1) {
+ (orderSpec, projection(), projection())
+ } else if (orderSpec.size == 1) {
// Use only the first order expression when the offset is non-null.
- val sortExpr = windowSpec.orderSpec.head
+ val sortExpr = orderSpec.head
val expr = sortExpr.child
// Create the projection which returns the current 'value'.
val current = newMutableProjection(expr :: Nil, child.output)()
@@ -250,7 +253,7 @@ case class Window(
// Get all relevant projections.
val result = createResultProjection(unboundExpressions)
- val grouping = newProjection(windowSpec.partitionSpec, child.output)
+ val grouping = newProjection(partitionSpec, child.output)
// Manage the stream and the grouping.
var nextRow: InternalRow = EmptyRow
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
new file mode 100644
index 0000000000000..abca373b0c4f9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -0,0 +1,490 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.unsafe.KVIterator
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * The base class of [[SortBasedAggregationIterator]] and [[UnsafeHybridAggregationIterator]].
+ * It mainly contains two parts:
+ * 1. It initializes aggregate functions.
+ * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of
+ * its aggregate functions. `processRow` is the function to handle an input. `generateOutput`
+ * is used to generate result.
+ */
+abstract class AggregationIterator(
+ groupingKeyAttributes: Seq[Attribute],
+ valueAttributes: Seq[Attribute],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ outputsUnsafeRows: Boolean)
+ extends Iterator[InternalRow] with Logging {
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Initializing functions.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // An Seq of all AggregateExpressions.
+ // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
+ // are at the beginning of the allAggregateExpressions.
+ protected val allAggregateExpressions =
+ nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+ require(
+ allAggregateExpressions.map(_.mode).distinct.length <= 2,
+ s"$allAggregateExpressions are not supported becuase they have more than 2 distinct modes.")
+
+ /**
+ * The distinct modes of AggregateExpressions. Right now, we can handle the following mode:
+ * - Partial-only: all AggregateExpressions have the mode of Partial;
+ * - PartialMerge-only: all AggregateExpressions have the mode of PartialMerge;
+ * - Final-only: all AggregateExpressions have the mode of Final;
+ * - Final-Complete: some AggregateExpressions have the mode of Final and
+ * others have the mode of Complete;
+ * - Complete-only: nonCompleteAggregateExpressions is empty and we have AggregateExpressions
+ * with mode Complete in completeAggregateExpressions; and
+ * - Grouping-only: there is no AggregateExpression.
+ */
+ protected val aggregationMode: (Option[AggregateMode], Option[AggregateMode]) =
+ nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
+ completeAggregateExpressions.map(_.mode).distinct.headOption
+
+ // Initialize all AggregateFunctions by binding references if necessary,
+ // and set inputBufferOffset and mutableBufferOffset.
+ protected val allAggregateFunctions: Array[AggregateFunction2] = {
+ var mutableBufferOffset = 0
+ var inputBufferOffset: Int = initialInputBufferOffset
+ val functions = new Array[AggregateFunction2](allAggregateExpressions.length)
+ var i = 0
+ while (i < allAggregateExpressions.length) {
+ val func = allAggregateExpressions(i).aggregateFunction
+ val funcWithBoundReferences = allAggregateExpressions(i).mode match {
+ case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
+ // We need to create BoundReferences if the function is not an
+ // AlgebraicAggregate (it does not support code-gen) and the mode of
+ // this function is Partial or Complete because we will call eval of this
+ // function's children in the update method of this aggregate function.
+ // Those eval calls require BoundReferences to work.
+ BindReferences.bindReference(func, valueAttributes)
+ case _ =>
+ // We only need to set inputBufferOffset for aggregate functions with mode
+ // PartialMerge and Final.
+ func.withNewInputBufferOffset(inputBufferOffset)
+ inputBufferOffset += func.bufferSchema.length
+ func
+ }
+ // Set mutableBufferOffset for this function. It is important that setting
+ // mutableBufferOffset happens after all potential bindReference operations
+ // because bindReference will create a new instance of the function.
+ funcWithBoundReferences.withNewMutableBufferOffset(mutableBufferOffset)
+ mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
+ functions(i) = funcWithBoundReferences
+ i += 1
+ }
+ functions
+ }
+
+ // Positions of those non-algebraic aggregate functions in allAggregateFunctions.
+ // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
+ // func2 and func3 are non-algebraic aggregate functions.
+ // nonAlgebraicAggregateFunctionPositions will be [1, 2].
+ private[this] val allNonAlgebraicAggregateFunctionPositions: Array[Int] = {
+ val positions = new ArrayBuffer[Int]()
+ var i = 0
+ while (i < allAggregateFunctions.length) {
+ allAggregateFunctions(i) match {
+ case agg: AlgebraicAggregate =>
+ case _ => positions += i
+ }
+ i += 1
+ }
+ positions.toArray
+ }
+
+ // All AggregateFunctions functions with mode Partial, PartialMerge, or Final.
+ private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] =
+ allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
+
+ // All non-algebraic aggregate functions with mode Partial, PartialMerge, or Final.
+ private[this] val nonCompleteNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+ nonCompleteAggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }
+
+ // The projection used to initialize buffer values for all AlgebraicAggregates.
+ private[this] val algebraicInitialProjection = {
+ val initExpressions = allAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.initialValues
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ newMutableProjection(initExpressions, Nil)()
+ }
+
+ // All non-Algebraic AggregateFunctions.
+ private[this] val allNonAlgebraicAggregateFunctions =
+ allNonAlgebraicAggregateFunctionPositions.map(allAggregateFunctions)
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Methods and fields used by sub-classes.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // Initializing functions used to process a row.
+ protected val processRow: (MutableRow, InternalRow) => Unit = {
+ val rowToBeProcessed = new JoinedRow
+ val aggregationBufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+ aggregationMode match {
+ // Partial-only
+ case (Some(Partial), None) =>
+ val updateExpressions = nonCompleteAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ val algebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
+
+ (currentBuffer: MutableRow, row: InternalRow) => {
+ algebraicUpdateProjection.target(currentBuffer)
+ // Process all algebraic aggregate functions.
+ algebraicUpdateProjection(rowToBeProcessed(currentBuffer, row))
+ // Process all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
+ nonCompleteNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
+ }
+
+ // PartialMerge-only or Final-only
+ case (Some(PartialMerge), None) | (Some(Final), None) =>
+ val inputAggregationBufferSchema = if (initialInputBufferOffset == 0) {
+ // If initialInputBufferOffset is 0, the input value does not contain
+ // grouping keys.
+ // This part is pretty hacky.
+ allAggregateFunctions.flatMap(_.cloneBufferAttributes).toSeq
+ } else {
+ groupingKeyAttributes ++ allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+ }
+ // val inputAggregationBufferSchema =
+ // groupingKeyAttributes ++
+ // allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val mergeExpressions = nonCompleteAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ // This projection is used to merge buffer values for all AlgebraicAggregates.
+ val algebraicMergeProjection =
+ newMutableProjection(
+ mergeExpressions,
+ aggregationBufferSchema ++ inputAggregationBufferSchema)()
+
+ (currentBuffer: MutableRow, row: InternalRow) => {
+ // Process all algebraic aggregate functions.
+ algebraicMergeProjection.target(currentBuffer)(rowToBeProcessed(currentBuffer, row))
+ // Process all non-algebraic aggregate functions.
+ var i = 0
+ while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
+ nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row)
+ i += 1
+ }
+ }
+
+ // Final-Complete
+ case (Some(Final), Some(Complete)) =>
+ val completeAggregateFunctions: Array[AggregateFunction2] =
+ allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+ // All non-algebraic aggregate functions with mode Complete.
+ val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+ completeAggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }
+
+ // The first initialInputBufferOffset values of the input aggregation buffer is
+ // for grouping expressions and distinct columns.
+ val groupingAttributesAndDistinctColumns = valueAttributes.take(initialInputBufferOffset)
+
+ val completeOffsetExpressions =
+ Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+ // We do not touch buffer values of aggregate functions with the Final mode.
+ val finalOffsetExpressions =
+ Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+
+ val mergeInputSchema =
+ aggregationBufferSchema ++
+ groupingAttributesAndDistinctColumns ++
+ nonCompleteAggregateFunctions.flatMap(_.cloneBufferAttributes)
+ val mergeExpressions =
+ nonCompleteAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.mergeExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ } ++ completeOffsetExpressions
+ val finalAlgebraicMergeProjection =
+ newMutableProjection(mergeExpressions, mergeInputSchema)()
+
+ val updateExpressions =
+ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ val completeAlgebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
+
+ (currentBuffer: MutableRow, row: InternalRow) => {
+ val input = rowToBeProcessed(currentBuffer, row)
+ // For all aggregate functions with mode Complete, update buffers.
+ completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+ var i = 0
+ while (i < completeNonAlgebraicAggregateFunctions.length) {
+ completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
+
+ // For all aggregate functions with mode Final, merge buffers.
+ finalAlgebraicMergeProjection.target(currentBuffer)(input)
+ i = 0
+ while (i < nonCompleteNonAlgebraicAggregateFunctions.length) {
+ nonCompleteNonAlgebraicAggregateFunctions(i).merge(currentBuffer, row)
+ i += 1
+ }
+ }
+
+ // Complete-only
+ case (None, Some(Complete)) =>
+ val completeAggregateFunctions: Array[AggregateFunction2] =
+ allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+ // All non-algebraic aggregate functions with mode Complete.
+ val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+ completeAggregateFunctions.collect {
+ case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+ }
+
+ val updateExpressions =
+ completeAggregateFunctions.flatMap {
+ case ae: AlgebraicAggregate => ae.updateExpressions
+ case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+ }
+ val completeAlgebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)()
+
+ (currentBuffer: MutableRow, row: InternalRow) => {
+ val input = rowToBeProcessed(currentBuffer, row)
+ // For all aggregate functions with mode Complete, update buffers.
+ completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+ var i = 0
+ while (i < completeNonAlgebraicAggregateFunctions.length) {
+ completeNonAlgebraicAggregateFunctions(i).update(currentBuffer, row)
+ i += 1
+ }
+ }
+
+ // Grouping only.
+ case (None, None) => (currentBuffer: MutableRow, row: InternalRow) => {}
+
+ case other =>
+ sys.error(
+ s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " +
+ s"support evaluate modes $other in this iterator.")
+ }
+ }
+
+ // Initializing the function used to generate the output row.
+ protected val generateOutput: (InternalRow, MutableRow) => InternalRow = {
+ val rowToBeEvaluated = new JoinedRow
+ val safeOutoutRow = new GenericMutableRow(resultExpressions.length)
+ val mutableOutput = if (outputsUnsafeRows) {
+ UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutoutRow)
+ } else {
+ safeOutoutRow
+ }
+
+ aggregationMode match {
+ // Partial-only or PartialMerge-only: every output row is basically the values of
+ // the grouping expressions and the corresponding aggregation buffer.
+ case (Some(Partial), None) | (Some(PartialMerge), None) =>
+ // Because we cannot copy a joinedRow containing a UnsafeRow (UnsafeRow does not
+ // support generic getter), we create a mutable projection to output the
+ // JoinedRow(currentGroupingKey, currentBuffer)
+ val bufferSchema = nonCompleteAggregateFunctions.flatMap(_.bufferAttributes)
+ val resultProjection =
+ newMutableProjection(
+ groupingKeyAttributes ++ bufferSchema,
+ groupingKeyAttributes ++ bufferSchema)()
+ resultProjection.target(mutableOutput)
+
+ (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
+ resultProjection(rowToBeEvaluated(currentGroupingKey, currentBuffer))
+ // rowToBeEvaluated(currentGroupingKey, currentBuffer)
+ }
+
+ // Final-only, Complete-only and Final-Complete: every output row contains values representing
+ // resultExpressions.
+ case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
+ val bufferSchemata =
+ allAggregateFunctions.flatMap(_.bufferAttributes)
+ val evalExpressions = allAggregateFunctions.map {
+ case ae: AlgebraicAggregate => ae.evaluateExpression
+ case agg: AggregateFunction2 => NoOp
+ }
+ val algebraicEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
+ val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
+ // TODO: Use unsafe row.
+ val aggregateResult = new GenericMutableRow(aggregateResultSchema.length)
+ val resultProjection =
+ newMutableProjection(
+ resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)()
+ resultProjection.target(mutableOutput)
+
+ (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
+ // Generate results for all algebraic aggregate functions.
+ algebraicEvalProjection.target(aggregateResult)(currentBuffer)
+ // Generate results for all non-algebraic aggregate functions.
+ var i = 0
+ while (i < allNonAlgebraicAggregateFunctions.length) {
+ aggregateResult.update(
+ allNonAlgebraicAggregateFunctionPositions(i),
+ allNonAlgebraicAggregateFunctions(i).eval(currentBuffer))
+ i += 1
+ }
+ resultProjection(rowToBeEvaluated(currentGroupingKey, aggregateResult))
+ }
+
+ // Grouping-only: we only output values of grouping expressions.
+ case (None, None) =>
+ val resultProjection =
+ newMutableProjection(resultExpressions, groupingKeyAttributes)()
+ resultProjection.target(mutableOutput)
+
+ (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => {
+ resultProjection(currentGroupingKey)
+ }
+
+ case other =>
+ sys.error(
+ s"Could not evaluate ${nonCompleteAggregateExpressions} because we do not " +
+ s"support evaluate modes $other in this iterator.")
+ }
+ }
+
+ /** Initializes buffer values for all aggregate functions. */
+ protected def initializeBuffer(buffer: MutableRow): Unit = {
+ algebraicInitialProjection.target(buffer)(EmptyRow)
+ var i = 0
+ while (i < allNonAlgebraicAggregateFunctions.length) {
+ allNonAlgebraicAggregateFunctions(i).initialize(buffer)
+ i += 1
+ }
+ }
+
+ /**
+ * Creates a new aggregation buffer and initializes buffer values
+ * for all aggregate functions.
+ */
+ protected def newBuffer: MutableRow
+}
+
+object AggregationIterator {
+ def kvIterator(
+ groupingExpressions: Seq[NamedExpression],
+ newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = {
+ new KVIterator[InternalRow, InternalRow] {
+ private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes)
+
+ private[this] var groupingKey: InternalRow = _
+
+ private[this] var value: InternalRow = _
+
+ override def next(): Boolean = {
+ if (inputIter.hasNext) {
+ // Read the next input row.
+ val inputRow = inputIter.next()
+ // Get groupingKey based on groupingExpressions.
+ groupingKey = groupingKeyGenerator(inputRow)
+ // The value is the inputRow.
+ value = inputRow
+ true
+ } else {
+ false
+ }
+ }
+
+ override def getKey(): InternalRow = {
+ groupingKey
+ }
+
+ override def getValue(): InternalRow = {
+ value
+ }
+
+ override def close(): Unit = {
+ // Do nothing
+ }
+ }
+ }
+
+ def unsafeKVIterator(
+ groupingExpressions: Seq[NamedExpression],
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = {
+ new KVIterator[UnsafeRow, InternalRow] {
+ private[this] val groupingKeyGenerator =
+ UnsafeProjection.create(groupingExpressions, inputAttributes)
+
+ private[this] var groupingKey: UnsafeRow = _
+
+ private[this] var value: InternalRow = _
+
+ override def next(): Boolean = {
+ if (inputIter.hasNext) {
+ // Read the next input row.
+ val inputRow = inputIter.next()
+ // Get groupingKey based on groupingExpressions.
+ groupingKey = groupingKeyGenerator.apply(inputRow)
+ // The value is the inputRow.
+ value = inputRow
+ true
+ } else {
+ false
+ }
+ }
+
+ override def getKey(): UnsafeRow = {
+ groupingKey
+ }
+
+ override def getValue(): InternalRow = {
+ value
+ }
+
+ override def close(): Unit = {
+ // Do nothing
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
new file mode 100644
index 0000000000000..f4c14a9b3556f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.StructType
+
+case class SortBasedAggregate(
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override private[sql] lazy val metrics = Map(
+ "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+ override def outputsUnsafeRows: Boolean = false
+
+ override def canProcessUnsafeRows: Boolean = false
+
+ override def canProcessSafeRows: Boolean = true
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+ case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+ groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+ }
+
+ override def outputOrdering: Seq[SortOrder] = {
+ groupingExpressions.map(SortOrder(_, Ascending))
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ val numInputRows = longMetric("numInputRows")
+ val numOutputRows = longMetric("numOutputRows")
+ child.execute().mapPartitions { iter =>
+ // Because the constructor of an aggregation iterator will read at least the first row,
+ // we need to get the value of iter.hasNext first.
+ val hasInput = iter.hasNext
+ if (!hasInput && groupingExpressions.nonEmpty) {
+ // This is a grouped aggregate and the input iterator is empty,
+ // so return an empty iterator.
+ Iterator[InternalRow]()
+ } else {
+ val outputIter = SortBasedAggregationIterator.createFromInputIterator(
+ groupingExpressions,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection _,
+ newProjection _,
+ child.output,
+ iter,
+ outputsUnsafeRows,
+ numInputRows,
+ numOutputRows)
+ if (!hasInput && groupingExpressions.isEmpty) {
+ // There is no input and there is no grouping expressions.
+ // We need to output a single row as the output.
+ numOutputRows += 1
+ Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+ } else {
+ outputIter
+ }
+ }
+ }
+ }
+
+ override def simpleString: String = {
+ val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+ val keyString = groupingExpressions.mkString("[", ",", "]")
+ val functionString = allAggregateExpressions.mkString("[", ",", "]")
+ val outputString = output.mkString("[", ",", "]")
+ s"SortBasedAggregate(key=$keyString, functions=$functionString, output=$outputString)"
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
new file mode 100644
index 0000000000000..73d50e07cf0b5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2}
+import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.unsafe.KVIterator
+
+/**
+ * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been
+ * sorted by values of [[groupingKeyAttributes]].
+ */
+class SortBasedAggregationIterator(
+ groupingKeyAttributes: Seq[Attribute],
+ valueAttributes: Seq[Attribute],
+ inputKVIterator: KVIterator[InternalRow, InternalRow],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ outputsUnsafeRows: Boolean,
+ numInputRows: LongSQLMetric,
+ numOutputRows: LongSQLMetric)
+ extends AggregationIterator(
+ groupingKeyAttributes,
+ valueAttributes,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ outputsUnsafeRows) {
+
+ override protected def newBuffer: MutableRow = {
+ val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val bufferRowSize: Int = bufferSchema.length
+
+ val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
+ val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+
+ val buffer = if (useUnsafeBuffer) {
+ val unsafeProjection =
+ UnsafeProjection.create(bufferSchema.map(_.dataType))
+ unsafeProjection.apply(genericMutableBuffer)
+ } else {
+ genericMutableBuffer
+ }
+ initializeBuffer(buffer)
+ buffer
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Mutable states for sort based aggregation.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // The partition key of the current partition.
+ private[this] var currentGroupingKey: InternalRow = _
+
+ // The partition key of next partition.
+ private[this] var nextGroupingKey: InternalRow = _
+
+ // The first row of next partition.
+ private[this] var firstRowInNextGroup: InternalRow = _
+
+ // Indicates if we have a new group of rows from the sorted input iterator
+ private[this] var sortedInputHasNewGroup: Boolean = false
+
+ // The aggregation buffer used by the sort-based aggregation.
+ private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer
+
+ /** Processes rows in the current group. It will stop when it finds a new group. */
+ protected def processCurrentSortedGroup(): Unit = {
+ currentGroupingKey = nextGroupingKey
+ // Now, we will start to find all rows belonging to this group.
+ // We create a variable to track if we see the next group.
+ var findNextPartition = false
+ // firstRowInNextGroup is the first row of this group. We first process it.
+ processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+
+ // The search will stop when we see the next group or there is no
+ // input row left in the iter.
+ var hasNext = inputKVIterator.next()
+ while (!findNextPartition && hasNext) {
+ // Get the grouping key.
+ val groupingKey = inputKVIterator.getKey
+ val currentRow = inputKVIterator.getValue
+ numInputRows += 1
+
+ // Check if the current row belongs to the current group.
+ if (currentGroupingKey == groupingKey) {
+ processRow(sortBasedAggregationBuffer, currentRow)
+
+ hasNext = inputKVIterator.next()
+ } else {
+ // We find a new group.
+ findNextPartition = true
+ nextGroupingKey = groupingKey.copy()
+ firstRowInNextGroup = currentRow.copy()
+ }
+ }
+ // We have not seen a new group. It means that there is no new row in the input
+ // iter. The current group is the last group of the iter.
+ if (!findNextPartition) {
+ sortedInputHasNewGroup = false
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Iterator's public methods
+ ///////////////////////////////////////////////////////////////////////////
+
+ override final def hasNext: Boolean = sortedInputHasNewGroup
+
+ override final def next(): InternalRow = {
+ if (hasNext) {
+ // Process the current group.
+ processCurrentSortedGroup()
+ // Generate output row for the current group.
+ val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
+ // Initialize buffer values for the next group.
+ initializeBuffer(sortBasedAggregationBuffer)
+ numOutputRows += 1
+ outputRow
+ } else {
+ // no more result
+ throw new NoSuchElementException
+ }
+ }
+
+ protected def initialize(): Unit = {
+ if (inputKVIterator.next()) {
+ initializeBuffer(sortBasedAggregationBuffer)
+
+ nextGroupingKey = inputKVIterator.getKey().copy()
+ firstRowInNextGroup = inputKVIterator.getValue().copy()
+ numInputRows += 1
+ sortedInputHasNewGroup = true
+ } else {
+ // This inputIter is empty.
+ sortedInputHasNewGroup = false
+ }
+ }
+
+ initialize()
+
+ def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
+ initializeBuffer(sortBasedAggregationBuffer)
+ generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
+ }
+}
+
+object SortBasedAggregationIterator {
+ // scalastyle:off
+ def createFromInputIterator(
+ groupingExprs: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ nonCompleteAggregateAttributes: Seq[Attribute],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ newProjection: (Seq[Expression], Seq[Attribute]) => Projection,
+ inputAttributes: Seq[Attribute],
+ inputIter: Iterator[InternalRow],
+ outputsUnsafeRows: Boolean,
+ numInputRows: LongSQLMetric,
+ numOutputRows: LongSQLMetric): SortBasedAggregationIterator = {
+ val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) {
+ AggregationIterator.unsafeKVIterator(
+ groupingExprs,
+ inputAttributes,
+ inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]]
+ } else {
+ AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter)
+ }
+
+ new SortBasedAggregationIterator(
+ groupingExprs.map(_.toAttribute),
+ inputAttributes,
+ kvIterator,
+ nonCompleteAggregateExpressions,
+ nonCompleteAggregateAttributes,
+ completeAggregateExpressions,
+ completeAggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ outputsUnsafeRows,
+ numInputRows,
+ numOutputRows)
+ }
+ // scalastyle:on
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
new file mode 100644
index 0000000000000..99f51ba5b6935
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+case class TungstenAggregate(
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryNode {
+
+ override private[sql] lazy val metrics = Map(
+ "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+ override def outputsUnsafeRows: Boolean = true
+
+ override def canProcessUnsafeRows: Boolean = true
+
+ override def canProcessSafeRows: Boolean = true
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+ case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
+ // This is for testing. We force TungstenAggregationIterator to fall back to sort-based
+ // aggregation once it has processed a given number of input rows.
+ private val testFallbackStartsAt: Option[Int] = {
+ sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
+ case null | "" => None
+ case fallbackStartsAt => Some(fallbackStartsAt.toInt)
+ }
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ val numInputRows = longMetric("numInputRows")
+ val numOutputRows = longMetric("numOutputRows")
+
+ /**
+ * Set up the underlying unsafe data structures used before computing the parent partition.
+ * This makes sure our iterator is not starved by other operators in the same task.
+ */
+ def preparePartition(): TungstenAggregationIterator = {
+ new TungstenAggregationIterator(
+ groupingExpressions,
+ nonCompleteAggregateExpressions,
+ completeAggregateExpressions,
+ initialInputBufferOffset,
+ resultExpressions,
+ newMutableProjection,
+ child.output,
+ testFallbackStartsAt,
+ numInputRows,
+ numOutputRows)
+ }
+
+ /** Compute a partition using the iterator already set up previously. */
+ def executePartition(
+ context: TaskContext,
+ partitionIndex: Int,
+ aggregationIterator: TungstenAggregationIterator,
+ parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = {
+ val hasInput = parentIterator.hasNext
+ if (!hasInput) {
+      // We're not using the underlying map, so we can just free it here
+ aggregationIterator.free()
+ if (groupingExpressions.isEmpty) {
+ numOutputRows += 1
+ Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
+ } else {
+ // This is a grouped aggregate and the input iterator is empty,
+ // so return an empty iterator.
+ Iterator[UnsafeRow]()
+ }
+ } else {
+ aggregationIterator.start(parentIterator)
+ aggregationIterator
+ }
+ }
+
+ // Note: we need to set up the iterator in each partition before computing the
+ // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
+ val resultRdd = {
+ new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
+ child.execute(), preparePartition, executePartition, preservesPartitioning = true)
+ }
+ resultRdd.asInstanceOf[RDD[InternalRow]]
+ }
+
+ override def simpleString: String = {
+ val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+ testFallbackStartsAt match {
+ case None =>
+ val keyString = groupingExpressions.mkString("[", ",", "]")
+ val functionString = allAggregateExpressions.mkString("[", ",", "]")
+ val outputString = output.mkString("[", ",", "]")
+ s"TungstenAggregate(key=$keyString, functions=$functionString, output=$outputString)"
+ case Some(fallbackStartsAt) =>
+ s"TungstenAggregateWithControlledFallback $groupingExpressions " +
+ s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt"
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
new file mode 100644
index 0000000000000..af7e0fcedbe4e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -0,0 +1,709 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s.
+ *
+ * This iterator first uses hash-based aggregation to process input rows. It uses
+ * a hash map to store groups and their corresponding aggregation buffers. If
+ * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]],
+ * it switches to sort-based aggregation. The process of the switch has the following steps:
+ * - Step 1: Sort all entries of the hash map based on values of grouping expressions and
+ * spill them to disk.
+ *  - Step 2: Create an external sorter based on the spilled sorted map entries.
+ * - Step 3: Redirect all input rows to the external sorter.
+ * - Step 4: Get a sorted [[KVIterator]] from the external sorter.
+ * - Step 5: Initialize sort-based aggregation.
+ * Then, this iterator works in the way of sort-based aggregation.
+ *
+ * The code of this class is organized as follows:
+ * - Part 1: Initializing aggregate functions.
+ * - Part 2: Methods and fields used by setting aggregation buffer values,
+ * processing input rows from inputIter, and generating output
+ * rows.
+ * - Part 3: Methods and fields used by hash-based aggregation.
+ * - Part 4: Methods and fields used when we switch to sort-based aggregation.
+ * - Part 5: Methods and fields used by sort-based aggregation.
+ * - Part 6: Loads input and process input rows.
+ * - Part 7: Public methods of this iterator.
+ * - Part 8: A utility function used to generate a result when there is no
+ * input and there is no grouping expression.
+ *
+ * @param groupingExpressions
+ * expressions for grouping keys
+ * @param nonCompleteAggregateExpressions
+ * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]],
+ * [[PartialMerge]], or [[Final]].
+ * @param completeAggregateExpressions
+ * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]].
+ * @param initialInputBufferOffset
+ * If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]].
+ * The input rows have the format of `grouping keys + aggregation buffer`.
+ *   This offset indicates the starting position of the aggregation buffer in an input row.
+ * @param resultExpressions
+ * expressions for generating output rows.
+ * @param newMutableProjection
+ * the function used to create mutable projections.
+ * @param originalInputAttributes
+ *   attributes representing input rows from `inputIter`.
+ */
+class TungstenAggregationIterator(
+ groupingExpressions: Seq[NamedExpression],
+ nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+ completeAggregateExpressions: Seq[AggregateExpression2],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+ originalInputAttributes: Seq[Attribute],
+ testFallbackStartsAt: Option[Int],
+ numInputRows: LongSQLMetric,
+ numOutputRows: LongSQLMetric)
+ extends Iterator[UnsafeRow] with Logging {
+
+ // The parent partition iterator, to be initialized later in `start`
+ private[this] var inputIter: Iterator[InternalRow] = null
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 1: Initializing aggregate functions.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // A Seq containing all AggregateExpressions.
+ // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
+ // are at the beginning of the allAggregateExpressions.
+ private[this] val allAggregateExpressions: Seq[AggregateExpression2] =
+ nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+  // Check to make sure we do not have more than two modes in our AggregateExpressions.
+ // If we have, users are hitting a bug and we throw an IllegalStateException.
+ if (allAggregateExpressions.map(_.mode).distinct.length > 2) {
+ throw new IllegalStateException(
+ s"$allAggregateExpressions should have no more than 2 kinds of modes.")
+ }
+
+ //
+ // The modes of AggregateExpressions. Right now, we can handle the following mode:
+ // - Partial-only:
+ // All AggregateExpressions have the mode of Partial.
+ // For this case, aggregationMode is (Some(Partial), None).
+ // - PartialMerge-only:
+  //    All AggregateExpressions have the mode of PartialMerge.
+ // For this case, aggregationMode is (Some(PartialMerge), None).
+ // - Final-only:
+ // All AggregateExpressions have the mode of Final.
+ // For this case, aggregationMode is (Some(Final), None).
+ // - Final-Complete:
+ // Some AggregateExpressions have the mode of Final and
+ // others have the mode of Complete. For this case,
+ // aggregationMode is (Some(Final), Some(Complete)).
+ // - Complete-only:
+ // nonCompleteAggregateExpressions is empty and we have AggregateExpressions
+ // with mode Complete in completeAggregateExpressions. For this case,
+ // aggregationMode is (None, Some(Complete)).
+ // - Grouping-only:
+ // There is no AggregateExpression. For this case, AggregationMode is (None,None).
+ //
+ private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = {
+ nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
+ completeAggregateExpressions.map(_.mode).distinct.headOption
+ }
+
+ // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates.
+  // If there are any functions that are not AlgebraicAggregates, we throw an
+ // IllegalStateException.
+ private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = {
+ if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) {
+ throw new IllegalStateException(
+ "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.")
+ }
+
+ allAggregateExpressions
+ .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate])
+ .toArray
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 2: Methods and fields used by setting aggregation buffer values,
+ // processing input rows from inputIter, and generating output
+ // rows.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // The projection used to initialize buffer values.
+ private[this] val algebraicInitialProjection: MutableProjection = {
+ val initExpressions = allAggregateFunctions.flatMap(_.initialValues)
+ newMutableProjection(initExpressions, Nil)()
+ }
+
+ // Creates a new aggregation buffer and initializes buffer values.
+  // This function should only be called at most three times (when we create the hash map,
+ // when we switch to sort-based aggregation, and when we create the re-used buffer for
+ // sort-based aggregation).
+ private def createNewAggregationBuffer(): UnsafeRow = {
+ val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val bufferRowSize: Int = bufferSchema.length
+
+ val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
+ val unsafeProjection =
+ UnsafeProjection.create(bufferSchema.map(_.dataType))
+ val buffer = unsafeProjection.apply(genericMutableBuffer)
+ algebraicInitialProjection.target(buffer)(EmptyRow)
+ buffer
+ }
+
+ // Creates a function used to process a row based on the given inputAttributes.
+ private def generateProcessRow(
+ inputAttributes: Seq[Attribute]): (UnsafeRow, InternalRow) => Unit = {
+
+ val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
+ val joinedRow = new JoinedRow()
+
+ aggregationMode match {
+ // Partial-only
+ case (Some(Partial), None) =>
+ val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions)
+ val algebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+ (currentBuffer: UnsafeRow, row: InternalRow) => {
+ algebraicUpdateProjection.target(currentBuffer)
+ algebraicUpdateProjection(joinedRow(currentBuffer, row))
+ }
+
+ // PartialMerge-only or Final-only
+ case (Some(PartialMerge), None) | (Some(Final), None) =>
+ val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions)
+ // This projection is used to merge buffer values for all AlgebraicAggregates.
+ val algebraicMergeProjection =
+ newMutableProjection(
+ mergeExpressions,
+ aggregationBufferAttributes ++ inputAttributes)()
+
+ (currentBuffer: UnsafeRow, row: InternalRow) => {
+ // Process all algebraic aggregate functions.
+ algebraicMergeProjection.target(currentBuffer)
+ algebraicMergeProjection(joinedRow(currentBuffer, row))
+ }
+
+ // Final-Complete
+ case (Some(Final), Some(Complete)) =>
+ val nonCompleteAggregateFunctions: Array[AlgebraicAggregate] =
+ allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
+ val completeAggregateFunctions: Array[AlgebraicAggregate] =
+ allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+
+ val completeOffsetExpressions =
+ Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+ val mergeExpressions =
+ nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions
+ val finalAlgebraicMergeProjection =
+ newMutableProjection(
+ mergeExpressions,
+ aggregationBufferAttributes ++ inputAttributes)()
+
+ // We do not touch buffer values of aggregate functions with the Final mode.
+ val finalOffsetExpressions =
+ Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+ val updateExpressions =
+ finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions)
+ val completeAlgebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+ (currentBuffer: UnsafeRow, row: InternalRow) => {
+ val input = joinedRow(currentBuffer, row)
+ // For all aggregate functions with mode Complete, update the given currentBuffer.
+ completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+
+ // For all aggregate functions with mode Final, merge buffer values in row to
+ // currentBuffer.
+ finalAlgebraicMergeProjection.target(currentBuffer)(input)
+ }
+
+ // Complete-only
+ case (None, Some(Complete)) =>
+ val completeAggregateFunctions: Array[AlgebraicAggregate] =
+ allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+
+ val updateExpressions =
+ completeAggregateFunctions.flatMap(_.updateExpressions)
+ val completeAlgebraicUpdateProjection =
+ newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+ (currentBuffer: UnsafeRow, row: InternalRow) => {
+ completeAlgebraicUpdateProjection.target(currentBuffer)
+ // For all aggregate functions with mode Complete, update the given currentBuffer.
+ completeAlgebraicUpdateProjection(joinedRow(currentBuffer, row))
+ }
+
+ // Grouping only.
+ case (None, None) => (currentBuffer: UnsafeRow, row: InternalRow) => {}
+
+ case other =>
+ throw new IllegalStateException(
+ s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
+ }
+ }
+
+ // Creates a function used to generate output rows.
+ private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = {
+
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
+ val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
+
+ aggregationMode match {
+ // Partial-only or PartialMerge-only: every output row is basically the values of
+ // the grouping expressions and the corresponding aggregation buffer.
+ case (Some(Partial), None) | (Some(PartialMerge), None) =>
+ val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+ val bufferSchema = StructType.fromAttributes(bufferAttributes)
+ val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+
+ (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+ unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
+ }
+
+      // Final-only, Complete-only and Final-Complete: an output row is generated based on
+ // resultExpressions.
+ case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
+ val joinedRow = new JoinedRow()
+ val resultProjection =
+ UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)
+
+ (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+ resultProjection(joinedRow(currentGroupingKey, currentBuffer))
+ }
+
+      // Grouping-only: an output row is generated from values of grouping expressions.
+ case (None, None) =>
+ val resultProjection =
+ UnsafeProjection.create(resultExpressions, groupingAttributes)
+
+ (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+ resultProjection(currentGroupingKey)
+ }
+
+ case other =>
+ throw new IllegalStateException(
+ s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
+ }
+ }
+
+ // An UnsafeProjection used to extract grouping keys from the input rows.
+ private[this] val groupProjection =
+ UnsafeProjection.create(groupingExpressions, originalInputAttributes)
+
+  // A function used to process an input row. Its first argument is the aggregation buffer
+ // and the second argument is the input row.
+ private[this] var processRow: (UnsafeRow, InternalRow) => Unit =
+ generateProcessRow(originalInputAttributes)
+
+ // A function used to generate output rows based on the grouping keys (first argument)
+ // and the corresponding aggregation buffer (second argument).
+ private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow =
+ generateResultProjection()
+
+ // An aggregation buffer containing initial buffer values. It is used to
+ // initialize other aggregation buffers.
+ private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 3: Methods and fields used by hash-based aggregation.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // This is the hash map used for hash-based aggregation. It is backed by an
+ // UnsafeFixedWidthAggregationMap and it is used to store
+ // all groups and their corresponding aggregation buffers for hash-based aggregation.
+ private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
+ initialAggregationBuffer,
+ StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
+ StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
+ TaskContext.get.taskMemoryManager(),
+ SparkEnv.get.shuffleMemoryManager,
+ 1024 * 16, // initial capacity
+ SparkEnv.get.shuffleMemoryManager.pageSizeBytes,
+ false // disable tracking of performance metrics
+ )
+
+ // Exposed for testing
+ private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap
+
+ // The function used to read and process input rows. When processing input rows,
+ // it first uses hash-based aggregation by putting groups and their buffers in
+ // hashMap. If we could not allocate more memory for the map, we switch to
+ // sort-based aggregation (by calling switchToSortBasedAggregation).
+ private def processInputs(): Unit = {
+ assert(inputIter != null, "attempted to process input when iterator was null")
+ while (!sortBased && inputIter.hasNext) {
+ val newInput = inputIter.next()
+ numInputRows += 1
+ val groupingKey = groupProjection.apply(newInput)
+ val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
+ if (buffer == null) {
+ // buffer == null means that we could not allocate more memory.
+ // Now, we need to spill the map and switch to sort-based aggregation.
+ switchToSortBasedAggregation(groupingKey, newInput)
+ } else {
+ processRow(buffer, newInput)
+ }
+ }
+ }
+
+  // This function is only used for testing. It is basically the same as processInputs except
+  // that it switches to sort-based aggregation after `fallbackStartsAt` input rows have
+ // been processed.
+ private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
+ assert(inputIter != null, "attempted to process input when iterator was null")
+ var i = 0
+ while (!sortBased && inputIter.hasNext) {
+ val newInput = inputIter.next()
+ numInputRows += 1
+ val groupingKey = groupProjection.apply(newInput)
+ val buffer: UnsafeRow = if (i < fallbackStartsAt) {
+ hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
+ } else {
+ null
+ }
+ if (buffer == null) {
+ // buffer == null means that we could not allocate more memory.
+ // Now, we need to spill the map and switch to sort-based aggregation.
+ switchToSortBasedAggregation(groupingKey, newInput)
+ } else {
+ processRow(buffer, newInput)
+ }
+ i += 1
+ }
+ }
+
+ // The iterator created from hashMap. It is used to generate output rows when we
+ // are using hash-based aggregation.
+ private[this] var aggregationBufferMapIterator: KVIterator[UnsafeRow, UnsafeRow] = null
+
+ // Indicates if aggregationBufferMapIterator still has key-value pairs.
+ private[this] var mapIteratorHasNext: Boolean = false
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 4: Methods and fields used when we switch to sort-based aggregation.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // This sorter is used for sort-based aggregation. It is initialized as soon as
+ // we switch from hash-based to sort-based aggregation. Otherwise, it is not used.
+ private[this] var externalSorter: UnsafeKVExternalSorter = null
+
+ /**
+ * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
+ */
+ private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
+ assert(inputIter != null, "attempted to process input when iterator was null")
+ logInfo("falling back to sort based aggregation.")
+ // Step 1: Get the ExternalSorter containing sorted entries of the map.
+ externalSorter = hashMap.destructAndCreateExternalSorter()
+
+ // Step 2: Free the memory used by the map.
+ hashMap.free()
+
+ // Step 3: If we have aggregate function with mode Partial or Complete,
+ // we need to process input rows to get aggregation buffer.
+ // So, later in the sort-based aggregation iterator, we can do merge.
+ // If aggregate functions are with mode Final and PartialMerge,
+ // we just need to project the aggregation buffer from an input row.
+ val needsProcess = aggregationMode match {
+ case (Some(Partial), None) => true
+ case (None, Some(Complete)) => true
+ case (Some(Final), Some(Complete)) => true
+ case _ => false
+ }
+
+ // Note: Since we spill the sorter's contents immediately after creating it, we must insert
+ // something into the sorter here to ensure that we acquire at least a page of memory.
+ // This is done through `externalSorter.insertKV`, which will trigger the page allocation.
+ // Otherwise, children operators may steal the window of opportunity and starve our sorter.
+
+ if (needsProcess) {
+ // First, we create a buffer.
+ val buffer = createNewAggregationBuffer()
+
+ // Process firstKey and firstInput.
+ // Initialize buffer.
+ buffer.copyFrom(initialAggregationBuffer)
+ processRow(buffer, firstInput)
+ externalSorter.insertKV(firstKey, buffer)
+
+ // Process the rest of input rows.
+ while (inputIter.hasNext) {
+ val newInput = inputIter.next()
+ numInputRows += 1
+ val groupingKey = groupProjection.apply(newInput)
+ buffer.copyFrom(initialAggregationBuffer)
+ processRow(buffer, newInput)
+ externalSorter.insertKV(groupingKey, buffer)
+ }
+ } else {
+ // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer.
+ // We need to project the aggregation buffer part from an input row.
+ val buffer = createNewAggregationBuffer()
+ // The originalInputAttributes are using cloneBufferAttributes. So, we need to use
+ // allAggregateFunctions.flatMap(_.cloneBufferAttributes).
+ val bufferExtractor = newMutableProjection(
+ allAggregateFunctions.flatMap(_.cloneBufferAttributes),
+ originalInputAttributes)()
+ bufferExtractor.target(buffer)
+
+ // Insert firstKey and its buffer.
+ bufferExtractor(firstInput)
+ externalSorter.insertKV(firstKey, buffer)
+
+ // Insert the rest of input rows.
+ while (inputIter.hasNext) {
+ val newInput = inputIter.next()
+ numInputRows += 1
+ val groupingKey = groupProjection.apply(newInput)
+ bufferExtractor(newInput)
+ externalSorter.insertKV(groupingKey, buffer)
+ }
+ }
+
+ // Set aggregationMode, processRow, and generateOutput for sort-based aggregation.
+ val newAggregationMode = aggregationMode match {
+ case (Some(Partial), None) => (Some(PartialMerge), None)
+ case (None, Some(Complete)) => (Some(Final), None)
+ case (Some(Final), Some(Complete)) => (Some(Final), None)
+ case other => other
+ }
+ aggregationMode = newAggregationMode
+
+    // Basically, the value of the KVIterator returned by externalSorter
+    // will just be the aggregation buffer. Here, we use cloneBufferAttributes.
+ val newInputAttributes: Seq[Attribute] =
+ allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+
+ // Set up new processRow and generateOutput.
+ processRow = generateProcessRow(newInputAttributes)
+ generateOutput = generateResultProjection()
+
+ // Step 5: Get the sorted iterator from the externalSorter.
+ sortedKVIterator = externalSorter.sortedIterator()
+
+ // Step 6: Pre-load the first key-value pair from the sorted iterator to make
+ // hasNext idempotent.
+ sortedInputHasNewGroup = sortedKVIterator.next()
+
+ // Copy the first key and value (aggregation buffer).
+ if (sortedInputHasNewGroup) {
+ val key = sortedKVIterator.getKey
+ val value = sortedKVIterator.getValue
+ nextGroupingKey = key.copy()
+ currentGroupingKey = key.copy()
+ firstRowInNextGroup = value.copy()
+ }
+
+ // Step 7: set sortBased to true.
+ sortBased = true
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 5: Methods and fields used by sort-based aggregation.
+ ///////////////////////////////////////////////////////////////////////////
+
+ // Indicates if we are using sort-based aggregation. Because we first try to use
+ // hash-based aggregation, its initial value is false.
+ private[this] var sortBased: Boolean = false
+
+ // The KVIterator containing input rows for the sort-based aggregation. It will be
+ // set in switchToSortBasedAggregation when we switch to sort-based aggregation.
+ private[this] var sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = null
+
+ // The grouping key of the current group.
+ private[this] var currentGroupingKey: UnsafeRow = null
+
+ // The grouping key of next group.
+ private[this] var nextGroupingKey: UnsafeRow = null
+
+ // The first row of next group.
+ private[this] var firstRowInNextGroup: UnsafeRow = null
+
+  // Indicates if we have a new group of rows from the sorted input iterator.
+ private[this] var sortedInputHasNewGroup: Boolean = false
+
+ // The aggregation buffer used by the sort-based aggregation.
+ private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
+
+  // Processes rows in the current group. It will stop when it finds a new group.
+ private def processCurrentSortedGroup(): Unit = {
+ // First, we need to copy nextGroupingKey to currentGroupingKey.
+ currentGroupingKey.copyFrom(nextGroupingKey)
+ // Now, we will start to find all rows belonging to this group.
+ // We create a variable to track if we see the next group.
+ var findNextPartition = false
+ // firstRowInNextGroup is the first row of this group. We first process it.
+ processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+
+ // The search will stop when we see the next group or there is no
+ // input row left in the iter.
+    // Pre-load the first key-value pair so that evaluating the while loop's
+    // condition has no side effects (we do not trigger loading a new key-value
+    // pair when we evaluate the condition).
+ var hasNext = sortedKVIterator.next()
+ while (!findNextPartition && hasNext) {
+ // Get the grouping key and value (aggregation buffer).
+ val groupingKey = sortedKVIterator.getKey
+ val inputAggregationBuffer = sortedKVIterator.getValue
+
+      // Check if the current row belongs to the current group.
+ if (currentGroupingKey.equals(groupingKey)) {
+ processRow(sortBasedAggregationBuffer, inputAggregationBuffer)
+
+ hasNext = sortedKVIterator.next()
+ } else {
+ // We find a new group.
+ findNextPartition = true
+        // Copy the key and value, since the iterator's backing rows will be overwritten.
+ nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy()
+ firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = inputAggregationBuffer.copy()
+
+ }
+ }
+ // We have not seen a new group. It means that there is no new row in the input
+ // iter. The current group is the last group of the sortedKVIterator.
+ if (!findNextPartition) {
+ sortedInputHasNewGroup = false
+ sortedKVIterator.close()
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 6: Loads input rows and setup aggregationBufferMapIterator if we
+ // have not switched to sort-based aggregation.
+ ///////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Start processing input rows.
+ * Only after this method is called will this iterator be non-empty.
+ */
+ def start(parentIter: Iterator[InternalRow]): Unit = {
+ inputIter = parentIter
+ testFallbackStartsAt match {
+ case None =>
+ processInputs()
+ case Some(fallbackStartsAt) =>
+ // This is the testing path. processInputsWithControlledFallback is same as processInputs
+ // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
+ // have been processed.
+ processInputsWithControlledFallback(fallbackStartsAt)
+ }
+
+ // If we did not switch to sort-based aggregation in processInputs,
+ // we pre-load the first key-value pair from the map (to make hasNext idempotent).
+ if (!sortBased) {
+ // First, set aggregationBufferMapIterator.
+ aggregationBufferMapIterator = hashMap.iterator()
+ // Pre-load the first key-value pair from the aggregationBufferMapIterator.
+ mapIteratorHasNext = aggregationBufferMapIterator.next()
+ // If the map is empty, we just free it.
+ if (!mapIteratorHasNext) {
+ hashMap.free()
+ }
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 7: Iterator's public methods.
+ ///////////////////////////////////////////////////////////////////////////
+
+ override final def hasNext: Boolean = {
+ (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext)
+ }
+
+ override final def next(): UnsafeRow = {
+ if (hasNext) {
+ val res = if (sortBased) {
+ // Process the current group.
+ processCurrentSortedGroup()
+ // Generate output row for the current group.
+ val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
+ // Initialize buffer values for the next group.
+ sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
+
+ outputRow
+ } else {
+ // We did not fall back to sort-based aggregation.
+ val result =
+ generateOutput(
+ aggregationBufferMapIterator.getKey,
+ aggregationBufferMapIterator.getValue)
+
+        // Pre-load next key-value pair from aggregationBufferMapIterator to make hasNext
+ // idempotent.
+ mapIteratorHasNext = aggregationBufferMapIterator.next()
+
+ if (!mapIteratorHasNext) {
+ // If there is no input from aggregationBufferMapIterator, we copy current result.
+ val resultCopy = result.copy()
+ // Then, we free the map.
+ hashMap.free()
+
+ resultCopy
+ } else {
+ result
+ }
+ }
+
+ // If this is the last record, update the task's peak memory usage. Since we destroy
+ // the map to create the sorter, their memory usages should not overlap, so it is safe
+ // to just use the max of the two.
+ if (!hasNext) {
+ val mapMemory = hashMap.getPeakMemoryUsedBytes
+ val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
+ val peakMemory = Math.max(mapMemory, sorterMemory)
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory)
+ }
+ numOutputRows += 1
+ res
+ } else {
+ // no more result
+ throw new NoSuchElementException
+ }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ // Part 8: Utility functions
+ ///////////////////////////////////////////////////////////////////////////
+
+ /**
+   * Generate an output row when there is no input and there is no grouping expression.
+ */
+ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
+ assert(groupingExpressions.isEmpty)
+ assert(inputIter == null)
+ generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
+ }
+
+ /** Free memory used in the underlying map. */
+ def free(): Unit = {
+ hashMap.free()
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
deleted file mode 100644
index 98538c462bc89..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
-
-case class Aggregate2Sort(
- requiredChildDistributionExpressions: Option[Seq[Expression]],
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- aggregateAttributes: Seq[Attribute],
- resultExpressions: Seq[NamedExpression],
- child: SparkPlan)
- extends UnaryNode {
-
- override def canProcessUnsafeRows: Boolean = true
-
- override def references: AttributeSet = {
- val referencesInResults =
- AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
-
- AttributeSet(
- groupingExpressions.flatMap(_.references) ++
- aggregateExpressions.flatMap(_.references) ++
- referencesInResults)
- }
-
- override def requiredChildDistribution: List[Distribution] = {
- requiredChildDistributionExpressions match {
- case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
- case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
- case None => UnspecifiedDistribution :: Nil
- }
- }
-
- override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
- // TODO: We should not sort the input rows if they are just in reversed order.
- groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
- }
-
- override def outputOrdering: Seq[SortOrder] = {
- // It is possible that the child.outputOrdering starts with the required
- // ordering expressions (e.g. we require [a] as the sort expression and the
- // child's outputOrdering is [a, b]). We can only guarantee the output rows
- // are sorted by values of groupingExpressions.
- groupingExpressions.map(SortOrder(_, Ascending))
- }
-
- override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
- child.execute().mapPartitions { iter =>
- if (aggregateExpressions.length == 0) {
- new FinalSortAggregationIterator(
- groupingExpressions,
- Nil,
- Nil,
- resultExpressions,
- newMutableProjection,
- child.output,
- iter)
- } else {
- val aggregationIterator: SortAggregationIterator = {
- aggregateExpressions.map(_.mode).distinct.toList match {
- case Partial :: Nil =>
- new PartialSortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- newMutableProjection,
- child.output,
- iter)
- case PartialMerge :: Nil =>
- new PartialMergeSortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- newMutableProjection,
- child.output,
- iter)
- case Final :: Nil =>
- new FinalSortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- aggregateAttributes,
- resultExpressions,
- newMutableProjection,
- child.output,
- iter)
- case other =>
- sys.error(
- s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " +
- s"modes $other in this operator.")
- }
- }
-
- aggregationIterator
- }
- }
- }
-}
-
-case class FinalAndCompleteAggregate2Sort(
- previousGroupingExpressions: Seq[NamedExpression],
- groupingExpressions: Seq[NamedExpression],
- finalAggregateExpressions: Seq[AggregateExpression2],
- finalAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
- completeAggregateAttributes: Seq[Attribute],
- resultExpressions: Seq[NamedExpression],
- child: SparkPlan)
- extends UnaryNode {
- override def references: AttributeSet = {
- val referencesInResults =
- AttributeSet(resultExpressions.flatMap(_.references)) --
- AttributeSet(finalAggregateExpressions) --
- AttributeSet(completeAggregateExpressions)
-
- AttributeSet(
- groupingExpressions.flatMap(_.references) ++
- finalAggregateExpressions.flatMap(_.references) ++
- completeAggregateExpressions.flatMap(_.references) ++
- referencesInResults)
- }
-
- override def requiredChildDistribution: List[Distribution] = {
- if (groupingExpressions.isEmpty) {
- AllTuples :: Nil
- } else {
- ClusteredDistribution(groupingExpressions) :: Nil
- }
- }
-
- override def requiredChildOrdering: Seq[Seq[SortOrder]] =
- groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
-
- override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
- child.execute().mapPartitions { iter =>
-
- new FinalAndCompleteSortAggregationIterator(
- previousGroupingExpressions.length,
- groupingExpressions,
- finalAggregateExpressions,
- finalAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- resultExpressions,
- newMutableProjection,
- child.output,
- iter)
- }
- }
-
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
deleted file mode 100644
index 2ca0cb82c1aab..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
+++ /dev/null
@@ -1,664 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.types.NullType
-
-import scala.collection.mutable.ArrayBuffer
-
-/**
- * An iterator used to evaluate aggregate functions. It assumes that input rows
- * are already grouped by values of `groupingExpressions`.
- */
-private[sql] abstract class SortAggregationIterator(
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends Iterator[InternalRow] {
-
- ///////////////////////////////////////////////////////////////////////////
- // Static fields for this iterator
- ///////////////////////////////////////////////////////////////////////////
-
- protected val aggregateFunctions: Array[AggregateFunction2] = {
- var mutableBufferOffset = 0
- var inputBufferOffset: Int = initialInputBufferOffset
- val functions = new Array[AggregateFunction2](aggregateExpressions.length)
- var i = 0
- while (i < aggregateExpressions.length) {
- val func = aggregateExpressions(i).aggregateFunction
- val funcWithBoundReferences = aggregateExpressions(i).mode match {
- case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
- // We need to create BoundReferences if the function is not an
- // AlgebraicAggregate (it does not support code-gen) and the mode of
- // this function is Partial or Complete because we will call eval of this
- // function's children in the update method of this aggregate function.
- // Those eval calls require BoundReferences to work.
- BindReferences.bindReference(func, inputAttributes)
- case _ =>
- // We only need to set inputBufferOffset for aggregate functions with mode
- // PartialMerge and Final.
- func.inputBufferOffset = inputBufferOffset
- inputBufferOffset += func.bufferSchema.length
- func
- }
- // Set mutableBufferOffset for this function. It is important that setting
- // mutableBufferOffset happens after all potential bindReference operations
- // because bindReference will create a new instance of the function.
- funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset
- mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
- functions(i) = funcWithBoundReferences
- i += 1
- }
- functions
- }
-
- // Positions of those non-algebraic aggregate functions in aggregateFunctions.
- // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
- // func2 and func3 are non-algebraic aggregate functions.
- // nonAlgebraicAggregateFunctionPositions will be [1, 2].
- protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
- val positions = new ArrayBuffer[Int]()
- var i = 0
- while (i < aggregateFunctions.length) {
- aggregateFunctions(i) match {
- case agg: AlgebraicAggregate =>
- case _ => positions += i
- }
- i += 1
- }
- positions.toArray
- }
-
- // All non-algebraic aggregate functions.
- protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
- nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions)
-
- // This is used to project expressions for the grouping expressions.
- protected val groupGenerator =
- newMutableProjection(groupingExpressions, inputAttributes)()
-
- // The underlying buffer shared by all aggregate functions.
- protected val buffer: MutableRow = {
- // The number of elements of the underlying buffer of this operator.
- // All aggregate functions are sharing this underlying buffer and they find their
- // buffer values through bufferOffset.
- // var size = 0
- // var i = 0
- // while (i < aggregateFunctions.length) {
- // size += aggregateFunctions(i).bufferSchema.length
- // i += 1
- // }
- new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum)
- }
-
- protected val joinedRow = new JoinedRow
-
- // This projection is used to initialize buffer values for all AlgebraicAggregates.
- protected val algebraicInitialProjection = {
- val initExpressions = aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.initialValues
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
-
- newMutableProjection(initExpressions, Nil)().target(buffer)
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Mutable states
- ///////////////////////////////////////////////////////////////////////////
-
- // The partition key of the current partition.
- protected var currentGroupingKey: InternalRow = _
- // The partition key of next partition.
- protected var nextGroupingKey: InternalRow = _
- // The first row of next partition.
- protected var firstRowInNextGroup: InternalRow = _
- // Indicates if we has new group of rows to process.
- protected var hasNewGroup: Boolean = true
-
- /** Initializes buffer values for all aggregate functions. */
- protected def initializeBuffer(): Unit = {
- algebraicInitialProjection(EmptyRow)
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- nonAlgebraicAggregateFunctions(i).initialize(buffer)
- i += 1
- }
- }
-
- protected def initialize(): Unit = {
- if (inputIter.hasNext) {
- initializeBuffer()
- val currentRow = inputIter.next().copy()
- // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey,
- // we are making a copy at here.
- nextGroupingKey = groupGenerator(currentRow).copy()
- firstRowInNextGroup = currentRow
- } else {
- // This iter is an empty one.
- hasNewGroup = false
- }
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Private methods
- ///////////////////////////////////////////////////////////////////////////
-
- /** Processes rows in the current group. It will stop when it find a new group. */
- private def processCurrentGroup(): Unit = {
- currentGroupingKey = nextGroupingKey
- // Now, we will start to find all rows belonging to this group.
- // We create a variable to track if we see the next group.
- var findNextPartition = false
- // firstRowInNextGroup is the first row of this group. We first process it.
- processRow(firstRowInNextGroup)
- // The search will stop when we see the next group or there is no
- // input row left in the iter.
- while (inputIter.hasNext && !findNextPartition) {
- val currentRow = inputIter.next()
- // Get the grouping key based on the grouping expressions.
- // For the below compare method, we do not need to make a copy of groupingKey.
- val groupingKey = groupGenerator(currentRow)
- // Check if the current row belongs the current input row.
- if (currentGroupingKey == groupingKey) {
- processRow(currentRow)
- } else {
- // We find a new group.
- findNextPartition = true
- nextGroupingKey = groupingKey.copy()
- firstRowInNextGroup = currentRow.copy()
- }
- }
- // We have not seen a new group. It means that there is no new row in the input
- // iter. The current group is the last group of the iter.
- if (!findNextPartition) {
- hasNewGroup = false
- }
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Public methods
- ///////////////////////////////////////////////////////////////////////////
-
- override final def hasNext: Boolean = hasNewGroup
-
- override final def next(): InternalRow = {
- if (hasNext) {
- // Process the current group.
- processCurrentGroup()
- // Generate output row for the current group.
- val outputRow = generateOutput()
- // Initilize buffer values for the next group.
- initializeBuffer()
-
- outputRow
- } else {
- // no more result
- throw new NoSuchElementException
- }
- }
-
- ///////////////////////////////////////////////////////////////////////////
- // Methods that need to be implemented
- ///////////////////////////////////////////////////////////////////////////
-
- /** The initial input buffer offset for `inputBufferOffset` of an [[AggregateFunction2]]. */
- protected def initialInputBufferOffset: Int
-
- /** The function used to process an input row. */
- protected def processRow(row: InternalRow): Unit
-
- /** The function used to generate the result row. */
- protected def generateOutput(): InternalRow
-
- ///////////////////////////////////////////////////////////////////////////
- // Initialize this iterator
- ///////////////////////////////////////////////////////////////////////////
-
- initialize()
-}
-
-/**
- * An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
- * It assumes that input rows are already grouped by values of `groupingExpressions`.
- * The format of its output rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- */
-class PartialSortAggregationIterator(
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends SortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- newMutableProjection,
- inputAttributes,
- inputIter) {
-
- // This projection is used to update buffer values for all AlgebraicAggregates.
- private val algebraicUpdateProjection = {
- val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes)
- val updateExpressions = aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
- newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
- }
-
- override protected def initialInputBufferOffset: Int = 0
-
- override protected def processRow(row: InternalRow): Unit = {
- // Process all algebraic aggregate functions.
- algebraicUpdateProjection(joinedRow(buffer, row))
- // Process all non-algebraic aggregate functions.
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- nonAlgebraicAggregateFunctions(i).update(buffer, row)
- i += 1
- }
- }
-
- override protected def generateOutput(): InternalRow = {
- // We just output the grouping expressions and the underlying buffer.
- joinedRow(currentGroupingKey, buffer).copy()
- }
-}
-
-/**
- * An iterator used to do partial merge aggregations (for those aggregate functions with mode
- * PartialMerge). It assumes that input rows are already grouped by values of
- * `groupingExpressions`.
- * The format of its input rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its internal buffer is:
- * |aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its output rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- */
-class PartialMergeSortAggregationIterator(
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends SortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- newMutableProjection,
- inputAttributes,
- inputIter) {
-
- // This projection is used to merge buffer values for all AlgebraicAggregates.
- private val algebraicMergeProjection = {
- val mergeInputSchema =
- aggregateFunctions.flatMap(_.bufferAttributes) ++
- groupingExpressions.map(_.toAttribute) ++
- aggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions = aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
-
- newMutableProjection(mergeExpressions, mergeInputSchema)()
- }
-
- override protected def initialInputBufferOffset: Int = groupingExpressions.length
-
- override protected def processRow(row: InternalRow): Unit = {
- // Process all algebraic aggregate functions.
- algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
- // Process all non-algebraic aggregate functions.
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- nonAlgebraicAggregateFunctions(i).merge(buffer, row)
- i += 1
- }
- }
-
- override protected def generateOutput(): InternalRow = {
- // We output grouping expressions and aggregation buffers.
- joinedRow(currentGroupingKey, buffer).copy()
- }
-}
-
-/**
- * An iterator used to do final aggregations (for those aggregate functions with mode
- * Final). It assumes that input rows are already grouped by values of
- * `groupingExpressions`.
- * The format of its input rows is:
- * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its internal buffer is:
- * |aggregationBuffer1|...|aggregationBufferN|
- *
- * The format of its output rows is represented by the schema of `resultExpressions`.
- */
-class FinalSortAggregationIterator(
- groupingExpressions: Seq[NamedExpression],
- aggregateExpressions: Seq[AggregateExpression2],
- aggregateAttributes: Seq[Attribute],
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends SortAggregationIterator(
- groupingExpressions,
- aggregateExpressions,
- newMutableProjection,
- inputAttributes,
- inputIter) {
-
- // The result of aggregate functions.
- private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
-
- // The projection used to generate the output rows of this operator.
- // This is only used when we are generating final results of aggregate functions.
- private val resultProjection =
- newMutableProjection(
- resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
-
- // This projection is used to merge buffer values for all AlgebraicAggregates.
- private val algebraicMergeProjection = {
- val mergeInputSchema =
- aggregateFunctions.flatMap(_.bufferAttributes) ++
- groupingExpressions.map(_.toAttribute) ++
- aggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions = aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
-
- newMutableProjection(mergeExpressions, mergeInputSchema)()
- }
-
- // This projection is used to evaluate all AlgebraicAggregates.
- private val algebraicEvalProjection = {
- val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
- val evalExpressions = aggregateFunctions.map {
- case ae: AlgebraicAggregate => ae.evaluateExpression
- case agg: AggregateFunction2 => NoOp
- }
-
- newMutableProjection(evalExpressions, bufferSchemata)()
- }
-
- override protected def initialInputBufferOffset: Int = groupingExpressions.length
-
- override def initialize(): Unit = {
- if (inputIter.hasNext) {
- initializeBuffer()
- val currentRow = inputIter.next().copy()
- // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey,
- // we are making a copy at here.
- nextGroupingKey = groupGenerator(currentRow).copy()
- firstRowInNextGroup = currentRow
- } else {
- if (groupingExpressions.isEmpty) {
- // If there is no grouping expression, we need to generate a single row as the output.
- initializeBuffer()
- // Right now, the buffer only contains initial buffer values. Because
- // merging two buffers with initial values will generate a row that
- // still store initial values. We set the currentRow as the copy of the current buffer.
- // Because input aggregation buffer has initialInputBufferOffset extra values at the
- // beginning, we create a dummy row for this part.
- val currentRow =
- joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
- nextGroupingKey = groupGenerator(currentRow).copy()
- firstRowInNextGroup = currentRow
- } else {
- // This iter is an empty one.
- hasNewGroup = false
- }
- }
- }
-
- override protected def processRow(row: InternalRow): Unit = {
- // Process all algebraic aggregate functions.
- algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
- // Process all non-algebraic aggregate functions.
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- nonAlgebraicAggregateFunctions(i).merge(buffer, row)
- i += 1
- }
- }
-
- override protected def generateOutput(): InternalRow = {
- // Generate results for all algebraic aggregate functions.
- algebraicEvalProjection.target(aggregateResult)(buffer)
- // Generate results for all non-algebraic aggregate functions.
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- aggregateResult.update(
- nonAlgebraicAggregateFunctionPositions(i),
- nonAlgebraicAggregateFunctions(i).eval(buffer))
- i += 1
- }
- resultProjection(joinedRow(currentGroupingKey, aggregateResult))
- }
-}
-
-/**
- * An iterator used to do both final aggregations (for those aggregate functions with mode
- * Final) and complete aggregations (for those aggregate functions with mode Complete).
- * It assumes that input rows are already grouped by values of `groupingExpressions`.
- * The format of its input rows is:
- * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN|
- * col1 to colM are columns used by aggregate functions with Complete mode.
- * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with
- * Final mode.
- *
- * The format of its internal buffer is:
- * |aggregationBuffer1|...|aggregationBuffer(N+M)|
- * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with
- * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode
- * Complete.
- *
- * The format of its output rows is represented by the schema of `resultExpressions`.
- */
-class FinalAndCompleteSortAggregationIterator(
- override protected val initialInputBufferOffset: Int,
- groupingExpressions: Seq[NamedExpression],
- finalAggregateExpressions: Seq[AggregateExpression2],
- finalAggregateAttributes: Seq[Attribute],
- completeAggregateExpressions: Seq[AggregateExpression2],
- completeAggregateAttributes: Seq[Attribute],
- resultExpressions: Seq[NamedExpression],
- newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
- inputAttributes: Seq[Attribute],
- inputIter: Iterator[InternalRow])
- extends SortAggregationIterator(
- groupingExpressions,
- // TODO: document the ordering
- finalAggregateExpressions ++ completeAggregateExpressions,
- newMutableProjection,
- inputAttributes,
- inputIter) {
-
- // The result of aggregate functions.
- private val aggregateResult: MutableRow =
- new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length)
-
- // The projection used to generate the output rows of this operator.
- // This is only used when we are generating final results of aggregate functions.
- private val resultProjection = {
- val inputSchema =
- groupingExpressions.map(_.toAttribute) ++
- finalAggregateAttributes ++
- completeAggregateAttributes
- newMutableProjection(resultExpressions, inputSchema)()
- }
-
- // All aggregate functions with mode Final.
- private val finalAggregateFunctions: Array[AggregateFunction2] = {
- val functions = new Array[AggregateFunction2](finalAggregateExpressions.length)
- var i = 0
- while (i < finalAggregateExpressions.length) {
- functions(i) = aggregateFunctions(i)
- i += 1
- }
- functions
- }
-
- // All non-algebraic aggregate functions with mode Final.
- private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
- finalAggregateFunctions.collect {
- case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
- }
-
- // All aggregate functions with mode Complete.
- private val completeAggregateFunctions: Array[AggregateFunction2] = {
- val functions = new Array[AggregateFunction2](completeAggregateExpressions.length)
- var i = 0
- while (i < completeAggregateExpressions.length) {
- functions(i) = aggregateFunctions(finalAggregateFunctions.length + i)
- i += 1
- }
- functions
- }
-
- // All non-algebraic aggregate functions with mode Complete.
- private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
- completeAggregateFunctions.collect {
- case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
- }
-
- // This projection is used to merge buffer values for all AlgebraicAggregates with mode
- // Final.
- private val finalAlgebraicMergeProjection = {
- // The first initialInputBufferOffset values of the input aggregation buffer is
- // for grouping expressions and distinct columns.
- val groupingAttributesAndDistinctColumns = inputAttributes.take(initialInputBufferOffset)
-
- val completeOffsetExpressions =
- Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
-
- val mergeInputSchema =
- finalAggregateFunctions.flatMap(_.bufferAttributes) ++
- completeAggregateFunctions.flatMap(_.bufferAttributes) ++
- groupingAttributesAndDistinctColumns ++
- finalAggregateFunctions.flatMap(_.cloneBufferAttributes)
- val mergeExpressions =
- finalAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.mergeExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- } ++ completeOffsetExpressions
- newMutableProjection(mergeExpressions, mergeInputSchema)()
- }
-
- // This projection is used to update buffer values for all AlgebraicAggregates with mode
- // Complete.
- private val completeAlgebraicUpdateProjection = {
- // We do not touch buffer values of aggregate functions with the Final mode.
- val finalOffsetExpressions =
- Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
-
- val bufferSchema =
- finalAggregateFunctions.flatMap(_.bufferAttributes) ++
- completeAggregateFunctions.flatMap(_.bufferAttributes)
- val updateExpressions =
- finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.updateExpressions
- case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
- }
- newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
- }
-
- // This projection is used to evaluate all AlgebraicAggregates.
- private val algebraicEvalProjection = {
- val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
- val evalExpressions = aggregateFunctions.map {
- case ae: AlgebraicAggregate => ae.evaluateExpression
- case agg: AggregateFunction2 => NoOp
- }
-
- newMutableProjection(evalExpressions, bufferSchemata)()
- }
-
- override def initialize(): Unit = {
- if (inputIter.hasNext) {
- initializeBuffer()
- val currentRow = inputIter.next().copy()
- // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey,
- // we are making a copy at here.
- nextGroupingKey = groupGenerator(currentRow).copy()
- firstRowInNextGroup = currentRow
- } else {
- if (groupingExpressions.isEmpty) {
- // If there is no grouping expression, we need to generate a single row as the output.
- initializeBuffer()
- // Right now, the buffer only contains initial buffer values. Because
- // merging two buffers with initial values will generate a row that
- // still store initial values. We set the currentRow as the copy of the current buffer.
- // Because input aggregation buffer has initialInputBufferOffset extra values at the
- // beginning, we create a dummy row for this part.
- val currentRow =
- joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
- nextGroupingKey = groupGenerator(currentRow).copy()
- firstRowInNextGroup = currentRow
- } else {
- // This iter is an empty one.
- hasNewGroup = false
- }
- }
- }
-
- override protected def processRow(row: InternalRow): Unit = {
- val input = joinedRow(buffer, row)
- // For all aggregate functions with mode Complete, update buffers.
- completeAlgebraicUpdateProjection(input)
- var i = 0
- while (i < completeNonAlgebraicAggregateFunctions.length) {
- completeNonAlgebraicAggregateFunctions(i).update(buffer, row)
- i += 1
- }
-
- // For all aggregate functions with mode Final, merge buffers.
- finalAlgebraicMergeProjection.target(buffer)(input)
- i = 0
- while (i < finalNonAlgebraicAggregateFunctions.length) {
- finalNonAlgebraicAggregateFunctions(i).merge(buffer, row)
- i += 1
- }
- }
-
- override protected def generateOutput(): InternalRow = {
- // Generate results for all algebraic aggregate functions.
- algebraicEvalProjection.target(aggregateResult)(buffer)
- // Generate results for all non-algebraic aggregate functions.
- var i = 0
- while (i < nonAlgebraicAggregateFunctions.length) {
- aggregateResult.update(
- nonAlgebraicAggregateFunctionPositions(i),
- nonAlgebraicAggregateFunctions(i).eval(buffer))
- i += 1
- }
-
- resultProjection(joinedRow(currentGroupingKey, aggregateResult))
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index cc54319171bdb..7619f3ec9f0a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -24,7 +24,154 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjecti
import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
-import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType}
+import org.apache.spark.sql.types._
+
+/**
+ * A helper trait used to create specialized setter and getter for types supported by
+ * [[org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap]]'s buffer.
+ * (see UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema).
+ */
+sealed trait BufferSetterGetterUtils {
+
+ def createGetters(schema: StructType): Array[(InternalRow, Int) => Any] = {
+ val dataTypes = schema.fields.map(_.dataType)
+ val getters = new Array[(InternalRow, Int) => Any](dataTypes.length)
+
+ var i = 0
+ while (i < getters.length) {
+ getters(i) = dataTypes(i) match {
+ case BooleanType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal)
+
+ case ByteType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getByte(ordinal)
+
+ case ShortType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getShort(ordinal)
+
+ case IntegerType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
+
+ case LongType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
+
+ case FloatType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getFloat(ordinal)
+
+ case DoubleType =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getDouble(ordinal)
+
+ case dt: DecimalType =>
+ val precision = dt.precision
+ val scale = dt.scale
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale)
+
+ case other =>
+ (row: InternalRow, ordinal: Int) =>
+ if (row.isNullAt(ordinal)) null else row.get(ordinal, other)
+ }
+
+ i += 1
+ }
+
+ getters
+ }
+
+ def createSetters(schema: StructType): Array[((MutableRow, Int, Any) => Unit)] = {
+ val dataTypes = schema.fields.map(_.dataType)
+ val setters = new Array[(MutableRow, Int, Any) => Unit](dataTypes.length)
+
+ var i = 0
+ while (i < setters.length) {
+ setters(i) = dataTypes(i) match {
+ case b: BooleanType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setBoolean(ordinal, value.asInstanceOf[Boolean])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case ByteType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setByte(ordinal, value.asInstanceOf[Byte])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case ShortType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setShort(ordinal, value.asInstanceOf[Short])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case IntegerType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setInt(ordinal, value.asInstanceOf[Int])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case LongType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setLong(ordinal, value.asInstanceOf[Long])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case FloatType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setFloat(ordinal, value.asInstanceOf[Float])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case DoubleType =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setDouble(ordinal, value.asInstanceOf[Double])
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case dt: DecimalType =>
+ val precision = dt.precision
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+ } else {
+ row.setNullAt(ordinal)
+ }
+
+ case other =>
+ (row: MutableRow, ordinal: Int, value: Any) =>
+ if (value != null) {
+ row.update(ordinal, value)
+ } else {
+ row.setNullAt(ordinal)
+ }
+ }
+
+ i += 1
+ }
+
+ setters
+ }
+}
/**
* A Mutable [[Row]] representing an mutable aggregation buffer.
@@ -35,7 +182,7 @@ private[sql] class MutableAggregationBufferImpl (
toScalaConverters: Array[Any => Any],
bufferOffset: Int,
var underlyingBuffer: MutableRow)
- extends MutableAggregationBuffer {
+ extends MutableAggregationBuffer with BufferSetterGetterUtils {
private[this] val offsets: Array[Int] = {
val newOffsets = new Array[Int](length)
@@ -47,6 +194,10 @@ private[sql] class MutableAggregationBufferImpl (
newOffsets
}
+ private[this] val bufferValueGetters = createGetters(schema)
+
+ private[this] val bufferValueSetters = createSetters(schema)
+
override def length: Int = toCatalystConverters.length
override def get(i: Int): Any = {
@@ -54,7 +205,7 @@ private[sql] class MutableAggregationBufferImpl (
throw new IllegalArgumentException(
s"Could not access ${i}th value in this buffer because it only has $length values.")
}
- toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType))
+ toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i)))
}
def update(i: Int, value: Any): Unit = {
@@ -62,7 +213,15 @@ private[sql] class MutableAggregationBufferImpl (
throw new IllegalArgumentException(
s"Could not update ${i}th value in this buffer because it only has $length values.")
}
- underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+
+ bufferValueSetters(i)(underlyingBuffer, offsets(i), toCatalystConverters(i)(value))
+ }
+
+  // Because the get method calls a specialized getter based on the schema, we cannot use
+  // the default implementation of isNullAt (which is get(i) == null).
+  // We have to override it to call isNullAt on the underlyingBuffer.
+ override def isNullAt(i: Int): Boolean = {
+ underlyingBuffer.isNullAt(offsets(i))
}
override def copy(): MutableAggregationBufferImpl = {
@@ -84,7 +243,7 @@ private[sql] class InputAggregationBuffer private[sql] (
toScalaConverters: Array[Any => Any],
bufferOffset: Int,
var underlyingInputBuffer: InternalRow)
- extends Row {
+ extends Row with BufferSetterGetterUtils {
private[this] val offsets: Array[Int] = {
val newOffsets = new Array[Int](length)
@@ -96,6 +255,10 @@ private[sql] class InputAggregationBuffer private[sql] (
newOffsets
}
+ private[this] val bufferValueGetters = createGetters(schema)
+
+ def getBufferOffset: Int = bufferOffset
+
override def length: Int = toCatalystConverters.length
override def get(i: Int): Any = {
@@ -103,8 +266,14 @@ private[sql] class InputAggregationBuffer private[sql] (
throw new IllegalArgumentException(
s"Could not access ${i}th value in this buffer because it only has $length values.")
}
- // TODO: Use buffer schema to avoid using generic getter.
- toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType))
+ toScalaConverters(i)(bufferValueGetters(i)(underlyingInputBuffer, offsets(i)))
+ }
+
+  // Because the get method calls a specialized getter based on the schema, we cannot use
+  // the default implementation of isNullAt (which is get(i) == null).
+  // We have to override it to call isNullAt on the underlyingInputBuffer.
+ override def isNullAt(i: Int): Boolean = {
+ underlyingInputBuffer.isNullAt(offsets(i))
}
override def copy(): InputAggregationBuffer = {
@@ -147,7 +316,7 @@ private[sql] case class ScalaUDAF(
override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
- val childrenSchema: StructType = {
+ private[this] lazy val childrenSchema: StructType = {
val inputFields = children.zipWithIndex.map {
case (child, index) =>
StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
@@ -155,7 +324,7 @@ private[sql] case class ScalaUDAF(
StructType(inputFields)
}
- lazy val inputProjection = {
+ private lazy val inputProjection = {
val inputAttributes = childrenSchema.toAttributes
log.debug(
s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
@@ -168,40 +337,68 @@ private[sql] case class ScalaUDAF(
}
}
- val inputToScalaConverters: Any => Any =
+ private[this] lazy val inputToScalaConverters: Any => Any =
CatalystTypeConverters.createToScalaConverter(childrenSchema)
- val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
- CatalystTypeConverters.createToCatalystConverter(field.dataType)
+ private[this] lazy val bufferValuesToCatalystConverters: Array[Any => Any] = {
+ bufferSchema.fields.map { field =>
+ CatalystTypeConverters.createToCatalystConverter(field.dataType)
+ }
}
- val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
- CatalystTypeConverters.createToScalaConverter(field.dataType)
+ private[this] lazy val bufferValuesToScalaConverters: Array[Any => Any] = {
+ bufferSchema.fields.map { field =>
+ CatalystTypeConverters.createToScalaConverter(field.dataType)
+ }
}
- lazy val inputAggregateBuffer: InputAggregationBuffer =
- new InputAggregationBuffer(
- bufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- inputBufferOffset,
- null)
-
- lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
- new MutableAggregationBufferImpl(
- bufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- mutableBufferOffset,
- null)
+ // This buffer is only used at executor side.
+ private[this] var inputAggregateBuffer: InputAggregationBuffer = null
+
+ // This buffer is only used at executor side.
+ private[this] var mutableAggregateBuffer: MutableAggregationBufferImpl = null
+
+ // This buffer is only used at executor side.
+ private[this] var evalAggregateBuffer: InputAggregationBuffer = null
+
+ /**
+ * Sets the inputBufferOffset to newInputBufferOffset and then create a new instance of
+ * `inputAggregateBuffer` based on this new inputBufferOffset.
+ */
+ override def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
+ super.withNewInputBufferOffset(newInputBufferOffset)
+ // inputBufferOffset has been updated.
+ inputAggregateBuffer =
+ new InputAggregationBuffer(
+ bufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ inputBufferOffset,
+ null)
+ }
- lazy val evalAggregateBuffer: InputAggregationBuffer =
- new InputAggregationBuffer(
- bufferSchema,
- bufferValuesToCatalystConverters,
- bufferValuesToScalaConverters,
- mutableBufferOffset,
- null)
+ /**
+ * Sets the mutableBufferOffset to newMutableBufferOffset and then create a new instance of
+ * `mutableAggregateBuffer` and `evalAggregateBuffer` based on this new mutableBufferOffset.
+ */
+ override def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
+ super.withNewMutableBufferOffset(newMutableBufferOffset)
+ // mutableBufferOffset has been updated.
+ mutableAggregateBuffer =
+ new MutableAggregationBufferImpl(
+ bufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ mutableBufferOffset,
+ null)
+ evalAggregateBuffer =
+ new InputAggregationBuffer(
+ bufferSchema,
+ bufferValuesToCatalystConverters,
+ bufferValuesToScalaConverters,
+ mutableBufferOffset,
+ null)
+ }
override def initialize(buffer: MutableRow): Unit = {
mutableAggregateBuffer.underlyingBuffer = buffer
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 03635baae4a5f..80816a095ea8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -17,24 +17,41 @@
package org.apache.spark.sql.execution.aggregate
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst._
+import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan}
+import org.apache.spark.sql.types.StructType
/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
*/
object Utils {
+ def supportsTungstenAggregate(
+ groupingExpressions: Seq[Expression],
+ aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+ val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
+
+ UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+ UnsafeProjection.canSupport(groupingExpressions)
+ }
+
def planAggregateWithoutDistinct(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[AggregateExpression2],
- aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+ aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
+ // Check if we can use TungstenAggregate.
+ val usesTungstenAggregate =
+ child.sqlContext.conf.unsafeEnabled &&
+ aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) &&
+ supportsTungstenAggregate(
+ groupingExpressions,
+ aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+
+
// 1. Create an Aggregate Operator for partial aggregations.
val namedGroupingExpressions = groupingExpressions.map {
case ne: NamedExpression => ne -> ne
@@ -48,43 +65,91 @@ object Utils {
val groupExpressionMap = namedGroupingExpressions.toMap
val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
- val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
- agg.aggregateFunction.bufferAttributes
+ val partialAggregateAttributes =
+ partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
+ val partialResultExpressions =
+ namedGroupingAttributes ++
+ partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+
+ val partialAggregate = if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ groupingExpressions = namedGroupingExpressions.map(_._2),
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ completeAggregateExpressions = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialResultExpressions,
+ child = child)
+ } else {
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ groupingExpressions = namedGroupingExpressions.map(_._2),
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ nonCompleteAggregateAttributes = partialAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialResultExpressions,
+ child = child)
}
- val partialAggregate =
- Aggregate2Sort(
- None: Option[Seq[Expression]],
- namedGroupingExpressions.map(_._2),
- partialAggregateExpressions,
- partialAggregateAttributes,
- namedGroupingAttributes ++ partialAggregateAttributes,
- child)
// 2. Create an Aggregate Operator for final aggregations.
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
val finalAggregateAttributes =
finalAggregateExpressions.map {
- expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
+ }
+
+ val finalAggregate = if (usesTungstenAggregate) {
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case agg: AggregateExpression2 =>
+ // aggregateFunctionMap contains unique aggregate functions.
+ val aggregateFunction =
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1
+ aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression
+ case expression =>
+            // We do not rely on the equality check here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ TungstenAggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ completeAggregateExpressions = Nil,
+ initialInputBufferOffset = namedGroupingAttributes.length,
+ resultExpressions = rewrittenResultExpressions,
+ child = partialAggregate)
+ } else {
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case agg: AggregateExpression2 =>
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+ case expression =>
+            // We do not rely on the equality check here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
}
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transformDown {
- case agg: AggregateExpression2 =>
- aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
- case expression =>
- // We do not rely on the equality check at here since attributes may
- // different cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
+
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ nonCompleteAggregateAttributes = finalAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = namedGroupingAttributes.length,
+ resultExpressions = rewrittenResultExpressions,
+ child = partialAggregate)
}
- val finalAggregate = Aggregate2Sort(
- Some(namedGroupingAttributes),
- namedGroupingAttributes,
- finalAggregateExpressions,
- finalAggregateAttributes,
- rewrittenResultExpressions,
- partialAggregate)
finalAggregate :: Nil
}
@@ -93,10 +158,18 @@ object Utils {
groupingExpressions: Seq[Expression],
functionsWithDistinct: Seq[AggregateExpression2],
functionsWithoutDistinct: Seq[AggregateExpression2],
- aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+ aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
+ val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct
+ val usesTungstenAggregate =
+ child.sqlContext.conf.unsafeEnabled &&
+ aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) &&
+ supportsTungstenAggregate(
+ groupingExpressions,
+ aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+
// 1. Create an Aggregate Operator for partial aggregations.
// The grouping expressions are original groupingExpressions and
// distinct columns. For example, for avg(distinct value) ... group by key
@@ -126,88 +199,160 @@ object Utils {
val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
- val partialAggregateExpressions = functionsWithoutDistinct.map {
- case AggregateExpression2(aggregateFunction, mode, _) =>
- AggregateExpression2(aggregateFunction, Partial, false)
+ val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+ val partialAggregateAttributes =
+ partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
+ val partialAggregateGroupingExpressions =
+ (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2)
+ val partialAggregateResult =
+ namedGroupingAttributes ++
+ distinctColumnAttributes ++
+ partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+ val partialAggregate = if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ groupingExpressions = partialAggregateGroupingExpressions,
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ completeAggregateExpressions = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialAggregateResult,
+ child = child)
+ } else {
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+ groupingExpressions = partialAggregateGroupingExpressions,
+ nonCompleteAggregateExpressions = partialAggregateExpressions,
+ nonCompleteAggregateAttributes = partialAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = 0,
+ resultExpressions = partialAggregateResult,
+ child = child)
}
- val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
- agg.aggregateFunction.bufferAttributes
- }
- val partialAggregate =
- Aggregate2Sort(
- None: Option[Seq[Expression]],
- (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
- partialAggregateExpressions,
- partialAggregateAttributes,
- namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
- child)
// 2. Create an Aggregate Operator for partial merge aggregations.
- val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
- case AggregateExpression2(aggregateFunction, mode, _) =>
- AggregateExpression2(aggregateFunction, PartialMerge, false)
- }
+ val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
val partialMergeAggregateAttributes =
- partialMergeAggregateExpressions.flatMap { agg =>
- agg.aggregateFunction.bufferAttributes
- }
- val partialMergeAggregate =
- Aggregate2Sort(
- Some(namedGroupingAttributes),
- namedGroupingAttributes ++ distinctColumnAttributes,
- partialMergeAggregateExpressions,
- partialMergeAggregateAttributes,
- namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
- partialAggregate)
+ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
+ val partialMergeAggregateResult =
+ namedGroupingAttributes ++
+ distinctColumnAttributes ++
+ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+ val partialMergeAggregate = if (usesTungstenAggregate) {
+ TungstenAggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+ nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+ completeAggregateExpressions = Nil,
+ initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+ resultExpressions = partialMergeAggregateResult,
+ child = partialAggregate)
+ } else {
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+ nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+ nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
+ completeAggregateExpressions = Nil,
+ completeAggregateAttributes = Nil,
+ initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+ resultExpressions = partialMergeAggregateResult,
+ child = partialAggregate)
+ }
// 3. Create an Aggregate Operator for partial merge aggregations.
- val finalAggregateExpressions = functionsWithoutDistinct.map {
- case AggregateExpression2(aggregateFunction, mode, _) =>
- AggregateExpression2(aggregateFunction, Final, false)
- }
+ val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
val finalAggregateAttributes =
finalAggregateExpressions.map {
- expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+ expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
}
+ // Create a map to store those rewritten aggregate functions. We always need to use
+    // both the function and its corresponding isDistinct flag as the key because the function
+    // itself does not know whether it has the distinct keyword or not.
+ val rewrittenAggregateFunctions =
+ mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2]
val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
// to AttributeReferences.
- case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+ case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
val rewrittenAggregateFunction = aggregateFunction.transformDown {
case expr if distinctColumnExpressionMap.contains(expr) =>
distinctColumnExpressionMap(expr).toAttribute
}.asInstanceOf[AggregateFunction2]
+ // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions
+ // to track the old version and the new version of this function.
+ rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction
// We rewrite the aggregate function to a non-distinct aggregation because
// its input will have distinct arguments.
+        // We keep the isDistinct setting as true, so when users look at the query plan,
+        // they can still see distinct aggregations.
val rewrittenAggregateExpression =
- AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+ AggregateExpression2(rewrittenAggregateFunction, Complete, true)
- val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+ val aggregateFunctionAttribute =
+ aggregateFunctionMap(agg.aggregateFunction, true)._2
(rewrittenAggregateExpression -> aggregateFunctionAttribute)
}.unzip
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transform {
- case agg: AggregateExpression2 =>
- aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
- case expression =>
- // We do not rely on the equality check at here since attributes may
- // different cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
+ val finalAndCompleteAggregate = if (usesTungstenAggregate) {
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transform {
+ case agg: AggregateExpression2 =>
+ val function = agg.aggregateFunction
+ val isDistinct = agg.isDistinct
+ val aggregateFunction =
+ if (rewrittenAggregateFunctions.contains(function, isDistinct)) {
+ // If this function has been rewritten, we get the rewritten version from
+ // rewrittenAggregateFunctions.
+ rewrittenAggregateFunctions(function, isDistinct)
+ } else {
+              // Otherwise, we get it from aggregateFunctionMap, which contains unique
+              // aggregate functions that have not been rewritten.
+ aggregateFunctionMap(function, isDistinct)._1
+ }
+ aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression
+ case expression =>
+            // We do not rely on the equality check here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ TungstenAggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ completeAggregateExpressions = completeAggregateExpressions,
+ initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+ resultExpressions = rewrittenResultExpressions,
+ child = partialMergeAggregate)
+ } else {
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transform {
+ case agg: AggregateExpression2 =>
+ aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+ case expression =>
+            // We do not rely on the equality check here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+ SortBasedAggregate(
+ requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+ groupingExpressions = namedGroupingAttributes,
+ nonCompleteAggregateExpressions = finalAggregateExpressions,
+ nonCompleteAggregateAttributes = finalAggregateAttributes,
+ completeAggregateExpressions = completeAggregateExpressions,
+ completeAggregateAttributes = completeAggregateAttributes,
+ initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+ resultExpressions = rewrittenResultExpressions,
+ child = partialMergeAggregate)
}
- val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
- namedGroupingAttributes ++ distinctColumnAttributes,
- namedGroupingAttributes,
- finalAggregateExpressions,
- finalAggregateAttributes,
- completeAggregateExpressions,
- completeAggregateAttributes,
- rewrittenResultExpressions,
- partialMergeAggregate)
finalAndCompleteAggregate :: Nil
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2294a670c735f..247c900baae9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
@@ -26,9 +26,11 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
+import org.apache.spark.util.random.PoissonSampler
import org.apache.spark.util.{CompletionIterator, MutablePair}
import org.apache.spark.{HashPartitioner, SparkEnv}
@@ -39,11 +41,20 @@ import org.apache.spark.{HashPartitioner, SparkEnv}
case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+ override private[sql] lazy val metrics = Map(
+ "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))
+
@transient lazy val buildProjection = newMutableProjection(projectList, child.output)
- protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
- val reusableProjection = buildProjection()
- iter.map(reusableProjection)
+ protected override def doExecute(): RDD[InternalRow] = {
+ val numRows = longMetric("numRows")
+ child.execute().mapPartitions { iter =>
+ val reusableProjection = buildProjection()
+ iter.map { row =>
+ numRows += 1
+ reusableProjection(row)
+ }
+ }
}
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -55,19 +66,28 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
*/
case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
+ override private[sql] lazy val metrics = Map(
+ "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))
+
override def outputsUnsafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true
override def canProcessSafeRows: Boolean = true
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
- protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
- this.transformAllExpressions {
- case CreateStruct(children) => CreateStructUnsafe(children)
- case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+ protected override def doExecute(): RDD[InternalRow] = {
+ val numRows = longMetric("numRows")
+ child.execute().mapPartitions { iter =>
+ this.transformAllExpressions {
+ case CreateStruct(children) => CreateStructUnsafe(children)
+ case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+ }
+ val project = UnsafeProjection.create(projectList, child.output)
+ iter.map { row =>
+ numRows += 1
+ project(row)
+ }
}
- val project = UnsafeProjection.create(projectList, child.output)
- iter.map(project)
}
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -81,8 +101,22 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
- protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
- iter.filter(newPredicate(condition, child.output))
+ private[sql] override lazy val metrics = Map(
+ "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val numInputRows = longMetric("numInputRows")
+ val numOutputRows = longMetric("numOutputRows")
+ child.execute().mapPartitions { iter =>
+ val predicate = newPredicate(condition, child.output)
+ iter.filter { row =>
+ numInputRows += 1
+ val r = predicate(row)
+ if (r) numOutputRows += 1
+ r
+ }
+ }
}
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -115,12 +149,21 @@ case class Sample(
{
override def output: Seq[Attribute] = child.output
- // TODO: How to pick seed?
+ override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
+
protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
- child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+ // Disable gap sampling since the gap sampling method buffers two rows internally,
+ // requiring us to copy the row, which is more expensive than the random number generator.
+ new PartitionwiseSampledRDD[InternalRow, InternalRow](
+ child.execute(),
+ new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
+ preservesPartitioning = true,
+ seed)
} else {
- child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+ child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
}
}
}
@@ -198,7 +241,9 @@ case class TakeOrderedAndProject(
override def outputPartitioning: Partitioning = SinglePartition
- private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
+ // We need to use an interpreted ordering here because generated orderings cannot be serialized
+ // and this ordering needs to be created on the driver in order to be passed into Spark core code.
+ private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
// TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
@transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
@@ -220,7 +265,6 @@ case class TakeOrderedAndProject(
override def outputOrdering: Seq[SortOrder] = sortOrder
}
-
/**
* :: DeveloperApi ::
* Return a new RDD that has exactly `numPartitions` partitions.
@@ -230,6 +274,11 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan)
extends UnaryNode {
override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = {
+ if (numPartitions == 1) SinglePartition
+ else UnknownPartitioning(numPartitions)
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
child.execute().map(_.copy()).coalesce(numPartitions, shuffle)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 6b83025d5a153..95209e6634519 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -69,6 +69,8 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan
val converted = sideEffectResult.map(convert(_).asInstanceOf[InternalRow])
sqlContext.sparkContext.parallelize(converted, 1)
}
+
+ override def argString: String = cmd.toString
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
new file mode 100644
index 0000000000000..6c462fa30461b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
@@ -0,0 +1,185 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.datasources
+
+import scala.language.implicitConversions
+import scala.util.matching.Regex
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.catalyst.{TableIdentifier, AbstractSparkSQLParser}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types._
+
+
+/**
+ * A parser for foreign DDL commands.
+ */
+class DDLParser(parseQuery: String => LogicalPlan)
+ extends AbstractSparkSQLParser with DataTypeParser with Logging {
+
+ def parse(input: String, exceptionOnError: Boolean): LogicalPlan = {
+ try {
+ parse(input)
+ } catch {
+ case ddlException: DDLException => throw ddlException
+ case _ if !exceptionOnError => parseQuery(input)
+ case x: Throwable => throw x
+ }
+ }
+
+ // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
+ // properties via reflection the class in runtime for constructing the SqlLexical object
+ protected val CREATE = Keyword("CREATE")
+ protected val TEMPORARY = Keyword("TEMPORARY")
+ protected val TABLE = Keyword("TABLE")
+ protected val IF = Keyword("IF")
+ protected val NOT = Keyword("NOT")
+ protected val EXISTS = Keyword("EXISTS")
+ protected val USING = Keyword("USING")
+ protected val OPTIONS = Keyword("OPTIONS")
+ protected val DESCRIBE = Keyword("DESCRIBE")
+ protected val EXTENDED = Keyword("EXTENDED")
+ protected val AS = Keyword("AS")
+ protected val COMMENT = Keyword("COMMENT")
+ protected val REFRESH = Keyword("REFRESH")
+
+ protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable
+
+ protected def start: Parser[LogicalPlan] = ddl
+
+ /**
+ * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS]
+ * USING org.apache.spark.sql.avro
+ * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
+ * or
+ * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS]
+ * USING org.apache.spark.sql.avro
+ * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
+ * or
+ * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS]
+ * USING org.apache.spark.sql.avro
+ * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
+ * AS SELECT ...
+ */
+ protected lazy val createTable: Parser[LogicalPlan] = {
+ // TODO: Support database.table.
+ (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~
+ tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ {
+ case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query =>
+ if (temp.isDefined && allowExisting.isDefined) {
+ throw new DDLException(
+ "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.")
+ }
+
+ val options = opts.getOrElse(Map.empty[String, String])
+ if (query.isDefined) {
+ if (columns.isDefined) {
+ throw new DDLException(
+ "a CREATE TABLE AS SELECT statement does not allow column definitions.")
+ }
+          // When the IF NOT EXISTS clause appears in the query, the save mode will be SaveMode.Ignore.
+ val mode = if (allowExisting.isDefined) {
+ SaveMode.Ignore
+ } else if (temp.isDefined) {
+ SaveMode.Overwrite
+ } else {
+ SaveMode.ErrorIfExists
+ }
+
+ val queryPlan = parseQuery(query.get)
+ CreateTableUsingAsSelect(tableName,
+ provider,
+ temp.isDefined,
+ Array.empty[String],
+ mode,
+ options,
+ queryPlan)
+ } else {
+ val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
+ CreateTableUsing(
+ tableName,
+ userSpecifiedSchema,
+ provider,
+ temp.isDefined,
+ options,
+ allowExisting.isDefined,
+ managedIfNoPath = false)
+ }
+ }
+ }
+
+ protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
+
+ /*
+   * describe [extended] table avroTable
+   * This will display all columns of table `avroTable`, including column_name, column_type and comment
+ */
+ protected lazy val describeTable: Parser[LogicalPlan] =
+ (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ {
+ case e ~ db ~ tbl =>
+ val tblIdentifier = db match {
+ case Some(dbName) =>
+ Seq(dbName, tbl)
+ case None =>
+ Seq(tbl)
+ }
+ DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined)
+ }
+
+ protected lazy val refreshTable: Parser[LogicalPlan] =
+ REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ {
+ case maybeDatabaseName ~ tableName =>
+ RefreshTable(TableIdentifier(tableName, maybeDatabaseName))
+ }
+
+ protected lazy val options: Parser[Map[String, String]] =
+ "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
+
+ protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}
+
+ override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch(
+ s"identifier matching regex $regex", {
+ case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str
+ case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str
+ }
+ )
+
+ protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ {
+ case name => name
+ }
+
+ protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ {
+ case parts => parts.mkString(".")
+ }
+
+ protected lazy val pair: Parser[(String, String)] =
+ optionName ~ stringLit ^^ { case k ~ v => (k, v) }
+
+ protected lazy val column: Parser[StructField] =
+ ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm =>
+ val meta = cm match {
+ case Some(comment) =>
+ new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build()
+ case None => Metadata.empty
+ }
+
+ StructField(columnName, typ, nullable = true, meta)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 6b91e51ca52fb..2a4c40db8bb66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -99,8 +99,9 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
(a, f) =>
toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f, t.paths, confBroadcast))) :: Nil
- case l @ LogicalRelation(t: TableScan) =>
- execution.PhysicalRDD(l.output, toCatalystRDD(l, t.buildScan())) :: Nil
+ case l @ LogicalRelation(baseRelation: TableScan) =>
+ execution.PhysicalRDD.createFromDataSource(
+ l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil
case i @ logical.InsertIntoTable(
l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty =>
@@ -167,7 +168,10 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows)
}
- execution.PhysicalRDD(projections.map(_.toAttribute), unionedRows)
+ execution.PhysicalRDD.createFromDataSource(
+ projections.map(_.toAttribute),
+ unionedRows,
+ logicalRelation.relation)
}
// TODO: refactor this thing. It is very complicated because it does projection internally.
@@ -187,15 +191,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// To see whether the `index`-th column is a partition column...
val i = partitionColumns.indexOf(name)
if (i != -1) {
+ val dt = schema(partitionColumns(i)).dataType
// If yes, gets column value from partition values.
(mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
- mutableRow(ordinal) = partitionValues.genericGet(i)
+ mutableRow(ordinal) = partitionValues.get(i, dt)
}
} else {
// Otherwise, inherits the value from scanned data.
val i = nonPartitionColumns.indexOf(name)
+ val dt = schema(nonPartitionColumns(i)).dataType
(mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
- mutableRow(ordinal) = dataRow.genericGet(i)
+ mutableRow(ordinal) = dataRow.get(i, dt)
}
}
}
@@ -295,14 +301,18 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
projects.asInstanceOf[Seq[Attribute]] // Safe due to if above.
.map(relation.attributeMap) // Match original case of attributes.
- val scan = execution.PhysicalRDD(projects.map(_.toAttribute),
- scanBuilder(requestedColumns, pushedFilters))
+ val scan = execution.PhysicalRDD.createFromDataSource(
+ projects.map(_.toAttribute),
+ scanBuilder(requestedColumns, pushedFilters),
+ relation.relation)
filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)
} else {
val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq
- val scan = execution.PhysicalRDD(requestedColumns,
- scanBuilder(requestedColumns, pushedFilters))
+ val scan = execution.PhysicalRDD.createFromDataSource(
+ requestedColumns,
+ scanBuilder(requestedColumns, pushedFilters),
+ relation.relation)
execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan))
}
}
@@ -339,6 +349,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
case expressions.EqualTo(Literal(v, _), a: Attribute) =>
Some(sources.EqualTo(a.name, v))
+ case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
+ Some(sources.EqualNullSafe(a.name, v))
+ case expressions.EqualNullSafe(Literal(v, _), a: Attribute) =>
+ Some(sources.EqualNullSafe(a.name, v))
+
case expressions.GreaterThan(a: Attribute, Literal(v, _)) =>
Some(sources.GreaterThan(a.name, v))
case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
@@ -362,6 +377,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
case expressions.InSet(a: Attribute, set) =>
Some(sources.In(a.name, set.toArray))
+ // Because we only convert In to InSet in Optimizer when there are more than certain
+ // items. So it is possible we still get an In expression here that needs to be pushed
+ // down.
+ case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
+ val hSet = list.map(e => e.eval(EmptyRow))
+ Some(sources.In(a.name, hSet.toArray))
+
case expressions.IsNull(a: Attribute) =>
Some(sources.IsNull(a.name))
case expressions.IsNotNull(a: Attribute) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala
new file mode 100644
index 0000000000000..6e4cc4de7f651
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala
@@ -0,0 +1,64 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.datasources
+
+import java.util.Properties
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCRelation, JDBCPartitioningInfo, DriverRegistry}
+import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
+
+
+class DefaultSource extends RelationProvider with DataSourceRegister {
+
+ override def shortName(): String = "jdbc"
+
+ /** Returns a new base relation with the given parameters. */
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
+ val driver = parameters.getOrElse("driver", null)
+ val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
+ val partitionColumn = parameters.getOrElse("partitionColumn", null)
+ val lowerBound = parameters.getOrElse("lowerBound", null)
+ val upperBound = parameters.getOrElse("upperBound", null)
+ val numPartitions = parameters.getOrElse("numPartitions", null)
+
+ if (driver != null) DriverRegistry.register(driver)
+
+ if (partitionColumn != null
+ && (lowerBound == null || upperBound == null || numPartitions == null)) {
+ sys.error("Partitioning incompletely specified")
+ }
+
+ val partitionInfo = if (partitionColumn == null) {
+ null
+ } else {
+ JDBCPartitioningInfo(
+ partitionColumn,
+ lowerBound.toLong,
+ upperBound.toLong,
+ numPartitions.toInt)
+ }
+ val parts = JDBCRelation.columnPartition(partitionInfo)
+ val properties = new Properties() // Additional properties that we will pass to getConnection
+ parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
+ JDBCRelation(url, table, parts, properties)(sqlContext)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala
new file mode 100644
index 0000000000000..3b7dc2e8d0210
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.sql.sources.InsertableRelation
+
+
+/**
+ * Inserts the results of `query` in to a relation that extends [[InsertableRelation]].
+ */
+private[sql] case class InsertIntoDataSource(
+ logicalRelation: LogicalRelation,
+ query: LogicalPlan,
+ overwrite: Boolean)
+ extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
+ val data = DataFrame(sqlContext, query)
+ // Apply the schema of the existing table to the new data.
+ val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
+ relation.insert(df, overwrite)
+
+ // Invalidate the cache.
+ sqlContext.cacheManager.invalidateCache(logicalRelation)
+
+ Seq.empty[Row]
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
new file mode 100644
index 0000000000000..735d52f808868
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.io.IOException
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
+import org.apache.spark._
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.{RunnableCommand, SQLExecution}
+import org.apache.spark.sql.sources._
+import org.apache.spark.util.Utils
+
+
+/**
+ * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
+ * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a
+ * single write job, and owns a UUID that identifies this job. Each concrete implementation of
+ * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for
+ * each task output file. This UUID is passed to executor side via a property named
+ * `spark.sql.sources.writeJobUUID`.
+ *
+ * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]]
+ * are used to write to normal tables and tables with dynamic partitions.
+ *
+ * Basic work flow of this command is:
+ *
+ * 1. Driver side setup, including output committer initialization and data source specific
+ * preparation work for the write job to be issued.
+ * 2. Issues a write job consisting of one or more executor-side tasks, each of which writes all
+ * rows within an RDD partition.
+ * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any
+ * exception is thrown during task commitment, also aborts that task.
+ * 4. If all tasks are committed, commits the job, otherwise aborts the job; if any exception is
+ * thrown during job commitment, also aborts the job.
+ */
+private[sql] case class InsertIntoHadoopFsRelation(
+ @transient relation: HadoopFsRelation,
+ @transient query: LogicalPlan,
+ mode: SaveMode)
+ extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ require(
+ relation.paths.length == 1,
+ s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}")
+
+ val hadoopConf = sqlContext.sparkContext.hadoopConfiguration
+ val outputPath = new Path(relation.paths.head)
+ val fs = outputPath.getFileSystem(hadoopConf)
+ val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+
+ val pathExists = fs.exists(qualifiedOutputPath)
+ val doInsertion = (mode, pathExists) match {
+ case (SaveMode.ErrorIfExists, true) =>
+ throw new AnalysisException(s"path $qualifiedOutputPath already exists.")
+ case (SaveMode.Overwrite, true) =>
+ Utils.tryOrIOException {
+ if (!fs.delete(qualifiedOutputPath, true /* recursively */)) {
+ throw new IOException(s"Unable to clear output " +
+ s"directory $qualifiedOutputPath prior to writing to it")
+ }
+ }
+ true
+ case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
+ true
+ case (SaveMode.Ignore, exists) =>
+ !exists
+ case (s, exists) =>
+ throw new IllegalStateException(s"unsupported save mode $s ($exists)")
+ }
+ // If we are appending data to an existing dir.
+ val isAppend = pathExists && (mode == SaveMode.Append)
+
+ if (doInsertion) {
+ val job = new Job(hadoopConf)
+ job.setOutputKeyClass(classOf[Void])
+ job.setOutputValueClass(classOf[InternalRow])
+ FileOutputFormat.setOutputPath(job, qualifiedOutputPath)
+
+      // A partitioned relation's schema can be different from the input logicalPlan, since
+      // partition columns are all moved after data columns. We use a Project to adjust the ordering.
+ // TODO: this belongs in the analyzer.
+ val project = Project(
+ relation.schema.map(field => UnresolvedAttribute.quoted(field.name)), query)
+ val queryExecution = DataFrame(sqlContext, project).queryExecution
+
+ SQLExecution.withNewExecutionId(sqlContext, queryExecution) {
+ val df = sqlContext.internalCreateDataFrame(queryExecution.toRdd, relation.schema)
+ val partitionColumns = relation.partitionColumns.fieldNames
+
+ // Some pre-flight checks.
+ require(
+ df.schema == relation.schema,
+ s"""DataFrame must have the same schema as the relation to which is inserted.
+ |DataFrame schema: ${df.schema}
+ |Relation schema: ${relation.schema}
+ """.stripMargin)
+ val partitionColumnsInSpec = relation.partitionColumns.fieldNames
+ require(
+ partitionColumnsInSpec.sameElements(partitionColumns),
+ s"""Partition columns mismatch.
+ |Expected: ${partitionColumnsInSpec.mkString(", ")}
+ |Actual: ${partitionColumns.mkString(", ")}
+ """.stripMargin)
+
+ val writerContainer = if (partitionColumns.isEmpty) {
+ new DefaultWriterContainer(relation, job, isAppend)
+ } else {
+ val output = df.queryExecution.executedPlan.output
+ val (partitionOutput, dataOutput) =
+ output.partition(a => partitionColumns.contains(a.name))
+
+ new DynamicPartitionWriterContainer(
+ relation,
+ job,
+ partitionOutput,
+ dataOutput,
+ output,
+ PartitioningUtils.DEFAULT_PARTITION_NAME,
+ sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES),
+ isAppend)
+ }
+
+ // This call shouldn't be put into the `try` block below because it only initializes and
+ // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
+ writerContainer.driverSideSetup()
+
+ try {
+ sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writerContainer.writeRows _)
+ writerContainer.commitJob()
+ relation.refresh()
+ } catch { case cause: Throwable =>
+ logError("Aborting job.", cause)
+ writerContainer.abortJob()
+ throw new SparkException("Job aborted.", cause)
+ }
+ }
+ } else {
+ logInfo("Skipping insertion into a relation that already exists.")
+ }
+
+ Seq.empty[Row]
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
new file mode 100644
index 0000000000000..7770bbd712f04
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -0,0 +1,204 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.datasources
+
+import java.util.ServiceLoader
+
+import scala.collection.JavaConversions._
+import scala.language.{existentials, implicitConversions}
+import scala.util.{Success, Failure, Try}
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.Logging
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.{DataFrame, SaveMode, AnalysisException, SQLContext}
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.{CalendarIntervalType, StructType}
+import org.apache.spark.util.Utils
+
+
+case class ResolvedDataSource(provider: Class[_], relation: BaseRelation)
+
+
+object ResolvedDataSource extends Logging {
+
+ /** A map to maintain backward compatibility in case we move data sources around. */
+ private val backwardCompatibilityMap = Map(
+ "org.apache.spark.sql.jdbc" -> classOf[jdbc.DefaultSource].getCanonicalName,
+ "org.apache.spark.sql.jdbc.DefaultSource" -> classOf[jdbc.DefaultSource].getCanonicalName,
+ "org.apache.spark.sql.json" -> classOf[json.DefaultSource].getCanonicalName,
+ "org.apache.spark.sql.json.DefaultSource" -> classOf[json.DefaultSource].getCanonicalName,
+ "org.apache.spark.sql.parquet" -> classOf[parquet.DefaultSource].getCanonicalName,
+ "org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName
+ )
+
+ /** Given a provider name, look up the data source class definition. */
+ def lookupDataSource(provider0: String): Class[_] = {
+ val provider = backwardCompatibilityMap.getOrElse(provider0, provider0)
+ val provider2 = s"$provider.DefaultSource"
+ val loader = Utils.getContextOrSparkClassLoader
+ val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)
+
+ serviceLoader.iterator().filter(_.shortName().equalsIgnoreCase(provider)).toList match {
+ /** the provider format did not match any given registered aliases */
+ case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match {
+ case Success(dataSource) => dataSource
+ case Failure(error) =>
+ if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
+ throw new ClassNotFoundException(
+ "The ORC data source must be used with Hive support enabled.", error)
+ } else {
+ throw new ClassNotFoundException(
+ s"Failed to load class for data source: $provider.", error)
+ }
+ }
+ /** there is exactly one registered alias */
+ case head :: Nil => head.getClass
+ /** There are multiple registered aliases for the input */
+ case sources => sys.error(s"Multiple sources found for $provider, " +
+ s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
+ "please specify the fully qualified class name.")
+ }
+ }
+
+ /** Create a [[ResolvedDataSource]] for reading data in. */
+ def apply(
+ sqlContext: SQLContext,
+ userSpecifiedSchema: Option[StructType],
+ partitionColumns: Array[String],
+ provider: String,
+ options: Map[String, String]): ResolvedDataSource = {
+ val clazz: Class[_] = lookupDataSource(provider)
+ def className: String = clazz.getCanonicalName
+ val relation = userSpecifiedSchema match {
+ case Some(schema: StructType) => clazz.newInstance() match {
+ case dataSource: SchemaRelationProvider =>
+ dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+ case dataSource: HadoopFsRelationProvider =>
+ val maybePartitionsSchema = if (partitionColumns.isEmpty) {
+ None
+ } else {
+ Some(partitionColumnsSchema(schema, partitionColumns))
+ }
+
+ val caseInsensitiveOptions = new CaseInsensitiveMap(options)
+ val paths = {
+ val patternPath = new Path(caseInsensitiveOptions("path"))
+ val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+ val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray
+ }
+
+ val dataSchema =
+ StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable
+
+ dataSource.createRelation(
+ sqlContext,
+ paths,
+ Some(dataSchema),
+ maybePartitionsSchema,
+ caseInsensitiveOptions)
+ case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+ throw new AnalysisException(s"$className does not allow user-specified schemas.")
+ case _ =>
+ throw new AnalysisException(s"$className is not a RelationProvider.")
+ }
+
+ case None => clazz.newInstance() match {
+ case dataSource: RelationProvider =>
+ dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+ case dataSource: HadoopFsRelationProvider =>
+ val caseInsensitiveOptions = new CaseInsensitiveMap(options)
+ val paths = {
+ val patternPath = new Path(caseInsensitiveOptions("path"))
+ val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+ val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray
+ }
+ dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions)
+ case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+ throw new AnalysisException(
+ s"A schema needs to be specified when using $className.")
+ case _ =>
+ throw new AnalysisException(
+ s"$className is neither a RelationProvider nor a FSBasedRelationProvider.")
+ }
+ }
+ new ResolvedDataSource(clazz, relation)
+ }
+
+ private def partitionColumnsSchema(
+ schema: StructType,
+ partitionColumns: Array[String]): StructType = {
+ StructType(partitionColumns.map { col =>
+ schema.find(_.name == col).getOrElse {
+ throw new RuntimeException(s"Partition column $col not found in schema $schema")
+ }
+ }).asNullable
+ }
+
+ /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */
+ def apply(
+ sqlContext: SQLContext,
+ provider: String,
+ partitionColumns: Array[String],
+ mode: SaveMode,
+ options: Map[String, String],
+ data: DataFrame): ResolvedDataSource = {
+ if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
+ throw new AnalysisException("Cannot save interval data type into external storage.")
+ }
+ val clazz: Class[_] = lookupDataSource(provider)
+ val relation = clazz.newInstance() match {
+ case dataSource: CreatableRelationProvider =>
+ dataSource.createRelation(sqlContext, mode, options, data)
+ case dataSource: HadoopFsRelationProvider =>
+ // Don't glob path for the write path. The contracts here are:
+ // 1. Only one output path can be specified on the write path;
+ // 2. Output path must be a legal HDFS style file system path;
+ // 3. It's OK that the output path doesn't exist yet;
+ val caseInsensitiveOptions = new CaseInsensitiveMap(options)
+ val outputPath = {
+ val path = new Path(caseInsensitiveOptions("path"))
+ val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+ path.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ }
+ val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name)))
+ val r = dataSource.createRelation(
+ sqlContext,
+ Array(outputPath.toString),
+ Some(dataSchema.asNullable),
+ Some(partitionColumnsSchema(data.schema, partitionColumns)),
+ caseInsensitiveOptions)
+
+ // For partitioned relation r, r.schema's column ordering can be different from the column
+ // ordering of data.logicalPlan (partition columns are all moved after data column). This
+ // will be adjusted within InsertIntoHadoopFsRelation.
+ sqlContext.executePlan(
+ InsertIntoHadoopFsRelation(
+ r,
+ data.logicalPlan,
+ mode)).toRdd
+ r
+ case _ =>
+ sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
+ }
+ ResolvedDataSource(clazz, relation)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
new file mode 100644
index 0000000000000..2f11f40422402
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.util.{Date, UUID}
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter}
+import org.apache.spark._
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
+import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.UnsafeKVExternalSorter
+import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
+import org.apache.spark.sql.types.{StructType, StringType}
+import org.apache.spark.util.SerializableConfiguration
+
+
+private[sql] abstract class BaseWriterContainer(
+ @transient val relation: HadoopFsRelation,
+ @transient job: Job,
+ isAppend: Boolean)
+ extends SparkHadoopMapReduceUtil
+ with Logging
+ with Serializable {
+
+ protected val dataSchema = relation.dataSchema
+
+ protected val serializableConf = new SerializableConfiguration(job.getConfiguration)
+
+ // This UUID is used to avoid output file name collision between different appending write jobs.
+ // These jobs may belong to different SparkContext instances. Concrete data source implementations
+  // may use this UUID to generate unique file names (e.g., `part-r-<task-id>-<job-uuid>.parquet`).
+ // The reason why this ID is used to identify a job rather than a single task output file is
+ // that, speculative tasks must generate the same output file name as the original task.
+ private val uniqueWriteJobId = UUID.randomUUID()
+
+ // This is only used on driver side.
+ @transient private val jobContext: JobContext = job
+
+ // The following fields are initialized and used on both driver and executor side.
+ @transient protected var outputCommitter: OutputCommitter = _
+ @transient private var jobId: JobID = _
+ @transient private var taskId: TaskID = _
+ @transient private var taskAttemptId: TaskAttemptID = _
+ @transient protected var taskAttemptContext: TaskAttemptContext = _
+
+ protected val outputPath: String = {
+ assert(
+ relation.paths.length == 1,
+ s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}")
+ relation.paths.head
+ }
+
+ protected var outputWriterFactory: OutputWriterFactory = _
+
+ private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _
+
+ def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit
+
+ def driverSideSetup(): Unit = {
+ setupIDs(0, 0, 0)
+ setupConf()
+
+ // This UUID is sent to executor side together with the serialized `Configuration` object within
+ // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate
+ // unique task output files.
+ job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString)
+
+ // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor
+ // clones the Configuration object passed in. If we initialize the TaskAttemptContext first,
+ // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext.
+ //
+ // Also, the `prepareJobForWrite` call must happen before initializing output format and output
+ // committer, since their initialization involve the job configuration, which can be potentially
+ // decorated in `prepareJobForWrite`.
+ outputWriterFactory = relation.prepareJobForWrite(job)
+ taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
+
+ outputFormatClass = job.getOutputFormatClass
+ outputCommitter = newOutputCommitter(taskAttemptContext)
+ outputCommitter.setupJob(jobContext)
+ }
+
+ def executorSideSetup(taskContext: TaskContext): Unit = {
+ setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber())
+ setupConf()
+ taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
+ outputCommitter = newOutputCommitter(taskAttemptContext)
+ outputCommitter.setupTask(taskAttemptContext)
+ }
+
+ protected def getWorkPath: String = {
+ outputCommitter match {
+ // FileOutputCommitter writes to a temporary location returned by `getWorkPath`.
+ case f: MapReduceFileOutputCommitter => f.getWorkPath.toString
+ case _ => outputPath
+ }
+ }
+
+ private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = {
+ val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context)
+
+ if (isAppend) {
+ // If we are appending data to an existing dir, we will only use the output committer
+ // associated with the file output format since it is not safe to use a custom
+ // committer for appending. For example, in S3, direct parquet output committer may
+      // leave partial data in the destination dir when the appending job fails.
+ logInfo(
+ s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " +
+ "for appending.")
+ defaultOutputCommitter
+ } else {
+ val committerClass = context.getConfiguration.getClass(
+ SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter])
+
+ Option(committerClass).map { clazz =>
+ logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}")
+
+ // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat
+ // has an associated output committer. To override this output committer,
+ // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS.
+ // If a data source needs to override the output committer, it needs to set the
+ // output committer in prepareForWrite method.
+ if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) {
+ // The specified output committer is a FileOutputCommitter.
+ // So, we will use the FileOutputCommitter-specified constructor.
+ val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
+ ctor.newInstance(new Path(outputPath), context)
+ } else {
+ // The specified output committer is just a OutputCommitter.
+ // So, we will use the no-argument constructor.
+ val ctor = clazz.getDeclaredConstructor()
+ ctor.newInstance()
+ }
+ }.getOrElse {
+ // If output committer class is not set, we will use the one associated with the
+ // file output format.
+ logInfo(
+ s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}")
+ defaultOutputCommitter
+ }
+ }
+ }
+
+ private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = {
+ this.jobId = SparkHadoopWriter.createJobID(new Date, jobId)
+ this.taskId = new TaskID(this.jobId, true, splitId)
+ this.taskAttemptId = new TaskAttemptID(taskId, attemptId)
+ }
+
+ private def setupConf(): Unit = {
+ serializableConf.value.set("mapred.job.id", jobId.toString)
+ serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString)
+ serializableConf.value.set("mapred.task.id", taskAttemptId.toString)
+ serializableConf.value.setBoolean("mapred.task.is.map", true)
+ serializableConf.value.setInt("mapred.task.partition", 0)
+ }
+
+ def commitTask(): Unit = {
+ SparkHadoopMapRedUtil.commitTask(
+ outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId)
+ }
+
+ def abortTask(): Unit = {
+ if (outputCommitter != null) {
+ outputCommitter.abortTask(taskAttemptContext)
+ }
+ logError(s"Task attempt $taskAttemptId aborted.")
+ }
+
+ def commitJob(): Unit = {
+ outputCommitter.commitJob(jobContext)
+ logInfo(s"Job $jobId committed.")
+ }
+
+ def abortJob(): Unit = {
+ if (outputCommitter != null) {
+ outputCommitter.abortJob(jobContext, JobStatus.State.FAILED)
+ }
+ logError(s"Job $jobId aborted.")
+ }
+}
+
+/**
+ * A writer that writes all of the rows in a partition to a single file.
+ */
+private[sql] class DefaultWriterContainer(
+ @transient relation: HadoopFsRelation,
+ @transient job: Job,
+ isAppend: Boolean)
+ extends BaseWriterContainer(relation, job, isAppend) {
+
+ def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
+ executorSideSetup(taskContext)
+ taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath)
+ val writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext)
+ writer.initConverter(dataSchema)
+
+ // If anything below fails, we should abort the task.
+ try {
+ while (iterator.hasNext) {
+ val internalRow = iterator.next()
+ writer.writeInternal(internalRow)
+ }
+
+ commitTask()
+ } catch {
+ case cause: Throwable =>
+ logError("Aborting task.", cause)
+ abortTask()
+ throw new SparkException("Task failed while writing rows.", cause)
+ }
+
+ def commitTask(): Unit = {
+ try {
+ assert(writer != null, "OutputWriter instance should have been initialized")
+ writer.close()
+ super.commitTask()
+ } catch {
+ case cause: Throwable =>
+ // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
+ // will cause `abortTask()` to be invoked.
+ throw new RuntimeException("Failed to commit task", cause)
+ }
+ }
+
+ def abortTask(): Unit = {
+ try {
+ writer.close()
+ } finally {
+ super.abortTask()
+ }
+ }
+ }
+}
+
+/**
+ * A writer that dynamically opens files based on the given partition columns. Internally this is
+ * done by maintaining a HashMap of open files until `maxFiles` is reached. If this occurs, the
+ * writer externally sorts the remaining rows and then writes them out one file at a time.
+ */
+private[sql] class DynamicPartitionWriterContainer(
+ @transient relation: HadoopFsRelation,
+ @transient job: Job,
+ partitionColumns: Seq[Attribute],
+ dataColumns: Seq[Attribute],
+ inputSchema: Seq[Attribute],
+ defaultPartitionName: String,
+ maxOpenFiles: Int,
+ isAppend: Boolean)
+ extends BaseWriterContainer(relation, job, isAppend) {
+
+ def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
+ val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
+ executorSideSetup(taskContext)
+
+ // Returns the partition key given an input row
+ val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema)
+ // Returns the data columns to be written given an input row
+ val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
+
+ // Expressions that given a partition key build a string like: col1=val/col2=val/...
+ val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+ val escaped =
+ ScalaUDF(
+ PartitioningUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType))
+ val str = If(IsNull(c), Literal(defaultPartitionName), escaped)
+ val partitionName = Literal(c.name + "=") :: str :: Nil
+ if (i == 0) partitionName else Literal(Path.SEPARATOR_CHAR.toString) :: partitionName
+ }
+
+ // Returns the partition path given a partition key.
+ val getPartitionString =
+ UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
+
+ // If anything below fails, we should abort the task.
+ try {
+ // This will be filled in if we have to fall back on sorting.
+ var sorter: UnsafeKVExternalSorter = null
+ while (iterator.hasNext && sorter == null) {
+ val inputRow = iterator.next()
+ val currentKey = getPartitionKey(inputRow)
+ var currentWriter = outputWriters.get(currentKey)
+
+ if (currentWriter == null) {
+ if (outputWriters.size < maxOpenFiles) {
+ currentWriter = newOutputWriter(currentKey)
+ outputWriters.put(currentKey.copy(), currentWriter)
+ currentWriter.writeInternal(getOutputRow(inputRow))
+ } else {
+ logInfo(s"Maximum partitions reached, falling back on sorting.")
+ sorter = new UnsafeKVExternalSorter(
+ StructType.fromAttributes(partitionColumns),
+ StructType.fromAttributes(dataColumns),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.shuffleMemoryManager,
+ SparkEnv.get.shuffleMemoryManager.pageSizeBytes)
+ sorter.insertKV(currentKey, getOutputRow(inputRow))
+ }
+ } else {
+ currentWriter.writeInternal(getOutputRow(inputRow))
+ }
+ }
+
+ // If the sorter is not null that means that we reached the maxFiles above and need to finish
+ // using external sort.
+ if (sorter != null) {
+ while (iterator.hasNext) {
+ val currentRow = iterator.next()
+ sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow))
+ }
+
+ logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+ val sortedIterator = sorter.sortedIterator()
+ var currentKey: InternalRow = null
+ var currentWriter: OutputWriter = null
+ try {
+ while (sortedIterator.next()) {
+ if (currentKey != sortedIterator.getKey) {
+ if (currentWriter != null) {
+ currentWriter.close()
+ }
+ currentKey = sortedIterator.getKey.copy()
+ logDebug(s"Writing partition: $currentKey")
+
+ // Either use an existing file from before, or open a new one.
+ currentWriter = outputWriters.remove(currentKey)
+ if (currentWriter == null) {
+ currentWriter = newOutputWriter(currentKey)
+ }
+ }
+
+ currentWriter.writeInternal(sortedIterator.getValue)
+ }
+ } finally {
+ if (currentWriter != null) { currentWriter.close() }
+ }
+ }
+
+ commitTask()
+ } catch {
+ case cause: Throwable =>
+ logError("Aborting task.", cause)
+ abortTask()
+ throw new SparkException("Task failed while writing rows.", cause)
+ }
+
+ /** Open and returns a new OutputWriter given a partition key. */
+ def newOutputWriter(key: InternalRow): OutputWriter = {
+ val partitionPath = getPartitionString(key).getString(0)
+ val path = new Path(getWorkPath, partitionPath)
+ taskAttemptContext.getConfiguration.set(
+ "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString)
+ val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext)
+ newWriter.initConverter(dataSchema)
+ newWriter
+ }
+
+ def clearOutputWriters(): Unit = {
+ outputWriters.asScala.values.foreach(_.close())
+ outputWriters.clear()
+ }
+
+ def commitTask(): Unit = {
+ try {
+ clearOutputWriters()
+ super.commitTask()
+ } catch {
+ case cause: Throwable =>
+ throw new RuntimeException("Failed to commit task", cause)
+ }
+ }
+
+ def abortTask(): Unit = {
+ try {
+ clearOutputWriters()
+ } finally {
+ super.abortTask()
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala
deleted file mode 100644
index d551f386eee6e..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala
+++ /dev/null
@@ -1,599 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources
-
-import java.util.{Date, UUID}
-
-import scala.collection.JavaConversions.asScalaIterator
-
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat}
-import org.apache.spark._
-import org.apache.spark.mapred.SparkHadoopMapRedUtil
-import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StringType
-import org.apache.spark.util.SerializableConfiguration
-
-
-private[sql] case class InsertIntoDataSource(
- logicalRelation: LogicalRelation,
- query: LogicalPlan,
- overwrite: Boolean)
- extends RunnableCommand {
-
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
- val data = DataFrame(sqlContext, query)
- // Apply the schema of the existing table to the new data.
- val df = sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
- relation.insert(df, overwrite)
-
- // Invalidate the cache.
- sqlContext.cacheManager.invalidateCache(logicalRelation)
-
- Seq.empty[Row]
- }
-}
-
-/**
- * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
- * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a
- * single write job, and owns a UUID that identifies this job. Each concrete implementation of
- * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for
- * each task output file. This UUID is passed to executor side via a property named
- * `spark.sql.sources.writeJobUUID`.
- *
- * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]]
- * are used to write to normal tables and tables with dynamic partitions.
- *
- * Basic work flow of this command is:
- *
- * 1. Driver side setup, including output committer initialization and data source specific
- * preparation work for the write job to be issued.
- * 2. Issues a write job consists of one or more executor side tasks, each of which writes all
- * rows within an RDD partition.
- * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any
- * exception is thrown during task commitment, also aborts that task.
- * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is
- * thrown during job commitment, also aborts the job.
- */
-private[sql] case class InsertIntoHadoopFsRelation(
- @transient relation: HadoopFsRelation,
- @transient query: LogicalPlan,
- mode: SaveMode)
- extends RunnableCommand {
-
- override def run(sqlContext: SQLContext): Seq[Row] = {
- require(
- relation.paths.length == 1,
- s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}")
-
- val hadoopConf = sqlContext.sparkContext.hadoopConfiguration
- val outputPath = new Path(relation.paths.head)
- val fs = outputPath.getFileSystem(hadoopConf)
- val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
-
- val pathExists = fs.exists(qualifiedOutputPath)
- val doInsertion = (mode, pathExists) match {
- case (SaveMode.ErrorIfExists, true) =>
- throw new AnalysisException(s"path $qualifiedOutputPath already exists.")
- case (SaveMode.Overwrite, true) =>
- fs.delete(qualifiedOutputPath, true)
- true
- case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
- true
- case (SaveMode.Ignore, exists) =>
- !exists
- case (s, exists) =>
- throw new IllegalStateException(s"unsupported save mode $s ($exists)")
- }
- // If we are appending data to an existing dir.
- val isAppend = pathExists && (mode == SaveMode.Append)
-
- if (doInsertion) {
- val job = new Job(hadoopConf)
- job.setOutputKeyClass(classOf[Void])
- job.setOutputValueClass(classOf[InternalRow])
- FileOutputFormat.setOutputPath(job, qualifiedOutputPath)
-
- // We create a DataFrame by applying the schema of relation to the data to make sure.
- // We are writing data based on the expected schema,
- val df = {
- // For partitioned relation r, r.schema's column ordering can be different from the column
- // ordering of data.logicalPlan (partition columns are all moved after data column). We
- // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can
- // safely apply the schema of r.schema to the data.
- val project = Project(
- relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query)
-
- sqlContext.internalCreateDataFrame(
- DataFrame(sqlContext, project).queryExecution.toRdd, relation.schema)
- }
-
- val partitionColumns = relation.partitionColumns.fieldNames
- if (partitionColumns.isEmpty) {
- insert(new DefaultWriterContainer(relation, job, isAppend), df)
- } else {
- val writerContainer = new DynamicPartitionWriterContainer(
- relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME, isAppend)
- insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns)
- }
- }
-
- Seq.empty[Row]
- }
-
- /**
- * Inserts the content of the [[DataFrame]] into a table without any partitioning columns.
- */
- private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = {
- // Uses local vals for serialization
- val needsConversion = relation.needConversion
- val dataSchema = relation.dataSchema
-
- // This call shouldn't be put into the `try` block below because it only initializes and
- // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
- writerContainer.driverSideSetup()
-
- try {
- df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _)
- writerContainer.commitJob()
- relation.refresh()
- } catch { case cause: Throwable =>
- logError("Aborting job.", cause)
- writerContainer.abortJob()
- throw new SparkException("Job aborted.", cause)
- }
-
- def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
- // If anything below fails, we should abort the task.
- try {
- writerContainer.executorSideSetup(taskContext)
-
- if (needsConversion) {
- val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
- .asInstanceOf[InternalRow => Row]
- while (iterator.hasNext) {
- val internalRow = iterator.next()
- writerContainer.outputWriterForRow(internalRow).write(converter(internalRow))
- }
- } else {
- while (iterator.hasNext) {
- val internalRow = iterator.next()
- writerContainer.outputWriterForRow(internalRow)
- .asInstanceOf[OutputWriterInternal].writeInternal(internalRow)
- }
- }
-
- writerContainer.commitTask()
- } catch { case cause: Throwable =>
- logError("Aborting task.", cause)
- writerContainer.abortTask()
- throw new SparkException("Task failed while writing rows.", cause)
- }
- }
- }
-
- /**
- * Inserts the content of the [[DataFrame]] into a table with partitioning columns.
- */
- private def insertWithDynamicPartitions(
- sqlContext: SQLContext,
- writerContainer: BaseWriterContainer,
- df: DataFrame,
- partitionColumns: Array[String]): Unit = {
- // Uses a local val for serialization
- val needsConversion = relation.needConversion
- val dataSchema = relation.dataSchema
-
- require(
- df.schema == relation.schema,
- s"""DataFrame must have the same schema as the relation to which is inserted.
- |DataFrame schema: ${df.schema}
- |Relation schema: ${relation.schema}
- """.stripMargin)
-
- val partitionColumnsInSpec = relation.partitionColumns.fieldNames
- require(
- partitionColumnsInSpec.sameElements(partitionColumns),
- s"""Partition columns mismatch.
- |Expected: ${partitionColumnsInSpec.mkString(", ")}
- |Actual: ${partitionColumns.mkString(", ")}
- """.stripMargin)
-
- val output = df.queryExecution.executedPlan.output
- val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name))
- val codegenEnabled = df.sqlContext.conf.codegenEnabled
-
- // This call shouldn't be put into the `try` block below because it only initializes and
- // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
- writerContainer.driverSideSetup()
-
- try {
- df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _)
- writerContainer.commitJob()
- relation.refresh()
- } catch { case cause: Throwable =>
- logError("Aborting job.", cause)
- writerContainer.abortJob()
- throw new SparkException("Job aborted.", cause)
- }
-
- def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
- // If anything below fails, we should abort the task.
- try {
- writerContainer.executorSideSetup(taskContext)
-
- // Projects all partition columns and casts them to strings to build partition directories.
- val partitionCasts = partitionOutput.map(Cast(_, StringType))
- val partitionProj = newProjection(codegenEnabled, partitionCasts, output)
- val dataProj = newProjection(codegenEnabled, dataOutput, output)
-
- if (needsConversion) {
- val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
- .asInstanceOf[InternalRow => Row]
- while (iterator.hasNext) {
- val internalRow = iterator.next()
- val partitionPart = partitionProj(internalRow)
- val dataPart = converter(dataProj(internalRow))
- writerContainer.outputWriterForRow(partitionPart).write(dataPart)
- }
- } else {
- while (iterator.hasNext) {
- val internalRow = iterator.next()
- val partitionPart = partitionProj(internalRow)
- val dataPart = dataProj(internalRow)
- writerContainer.outputWriterForRow(partitionPart)
- .asInstanceOf[OutputWriterInternal].writeInternal(dataPart)
- }
- }
-
- writerContainer.commitTask()
- } catch { case cause: Throwable =>
- logError("Aborting task.", cause)
- writerContainer.abortTask()
- throw new SparkException("Task failed while writing rows.", cause)
- }
- }
- }
-
- // This is copied from SparkPlan, probably should move this to a more general place.
- private def newProjection(
- codegenEnabled: Boolean,
- expressions: Seq[Expression],
- inputSchema: Seq[Attribute]): Projection = {
- log.debug(
- s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
- if (codegenEnabled) {
-
- try {
- GenerateProjection.generate(expressions, inputSchema)
- } catch {
- case e: Exception =>
- if (sys.props.contains("spark.testing")) {
- throw e
- } else {
- log.error("failed to generate projection, fallback to interpreted", e)
- new InterpretedProjection(expressions, inputSchema)
- }
- }
- } else {
- new InterpretedProjection(expressions, inputSchema)
- }
- }
-}
-
-private[sql] abstract class BaseWriterContainer(
- @transient val relation: HadoopFsRelation,
- @transient job: Job,
- isAppend: Boolean)
- extends SparkHadoopMapReduceUtil
- with Logging
- with Serializable {
-
- protected val serializableConf = new SerializableConfiguration(job.getConfiguration)
-
- // This UUID is used to avoid output file name collision between different appending write jobs.
- // These jobs may belong to different SparkContext instances. Concrete data source implementations
- // may use this UUID to generate unique file names (e.g., `part-r--.parquet`).
- // The reason why this ID is used to identify a job rather than a single task output file is
- // that, speculative tasks must generate the same output file name as the original task.
- private val uniqueWriteJobId = UUID.randomUUID()
-
- // This is only used on driver side.
- @transient private val jobContext: JobContext = job
-
- // The following fields are initialized and used on both driver and executor side.
- @transient protected var outputCommitter: OutputCommitter = _
- @transient private var jobId: JobID = _
- @transient private var taskId: TaskID = _
- @transient private var taskAttemptId: TaskAttemptID = _
- @transient protected var taskAttemptContext: TaskAttemptContext = _
-
- protected val outputPath: String = {
- assert(
- relation.paths.length == 1,
- s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}")
- relation.paths.head
- }
-
- protected val dataSchema = relation.dataSchema
-
- protected var outputWriterFactory: OutputWriterFactory = _
-
- private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _
-
- def driverSideSetup(): Unit = {
- setupIDs(0, 0, 0)
- setupConf()
-
- // This UUID is sent to executor side together with the serialized `Configuration` object within
- // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate
- // unique task output files.
- job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString)
-
- // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor
- // clones the Configuration object passed in. If we initialize the TaskAttemptContext first,
- // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext.
- //
- // Also, the `prepareJobForWrite` call must happen before initializing output format and output
- // committer, since their initialization involve the job configuration, which can be potentially
- // decorated in `prepareJobForWrite`.
- outputWriterFactory = relation.prepareJobForWrite(job)
- taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
-
- outputFormatClass = job.getOutputFormatClass
- outputCommitter = newOutputCommitter(taskAttemptContext)
- outputCommitter.setupJob(jobContext)
- }
-
- def executorSideSetup(taskContext: TaskContext): Unit = {
- setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber())
- setupConf()
- taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
- outputCommitter = newOutputCommitter(taskAttemptContext)
- outputCommitter.setupTask(taskAttemptContext)
- initWriters()
- }
-
- protected def getWorkPath: String = {
- outputCommitter match {
- // FileOutputCommitter writes to a temporary location returned by `getWorkPath`.
- case f: MapReduceFileOutputCommitter => f.getWorkPath.toString
- case _ => outputPath
- }
- }
-
- private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = {
- val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context)
-
- if (isAppend) {
- // If we are appending data to an existing dir, we will only use the output committer
- // associated with the file output format since it is not safe to use a custom
- // committer for appending. For example, in S3, direct parquet output committer may
- // leave partial data in the destination dir when the the appending job fails.
- logInfo(
- s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " +
- "for appending.")
- defaultOutputCommitter
- } else {
- val committerClass = context.getConfiguration.getClass(
- SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter])
-
- Option(committerClass).map { clazz =>
- logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}")
-
- // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat
- // has an associated output committer. To override this output committer,
- // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS.
- // If a data source needs to override the output committer, it needs to set the
- // output committer in prepareForWrite method.
- if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) {
- // The specified output committer is a FileOutputCommitter.
- // So, we will use the FileOutputCommitter-specified constructor.
- val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
- ctor.newInstance(new Path(outputPath), context)
- } else {
- // The specified output committer is just a OutputCommitter.
- // So, we will use the no-argument constructor.
- val ctor = clazz.getDeclaredConstructor()
- ctor.newInstance()
- }
- }.getOrElse {
- // If output committer class is not set, we will use the one associated with the
- // file output format.
- logInfo(
- s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}")
- defaultOutputCommitter
- }
- }
- }
-
- private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = {
- this.jobId = SparkHadoopWriter.createJobID(new Date, jobId)
- this.taskId = new TaskID(this.jobId, true, splitId)
- this.taskAttemptId = new TaskAttemptID(taskId, attemptId)
- }
-
- private def setupConf(): Unit = {
- serializableConf.value.set("mapred.job.id", jobId.toString)
- serializableConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString)
- serializableConf.value.set("mapred.task.id", taskAttemptId.toString)
- serializableConf.value.setBoolean("mapred.task.is.map", true)
- serializableConf.value.setInt("mapred.task.partition", 0)
- }
-
- // Called on executor side when writing rows
- def outputWriterForRow(row: InternalRow): OutputWriter
-
- protected def initWriters(): Unit
-
- def commitTask(): Unit = {
- SparkHadoopMapRedUtil.commitTask(
- outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId)
- }
-
- def abortTask(): Unit = {
- if (outputCommitter != null) {
- outputCommitter.abortTask(taskAttemptContext)
- }
- logError(s"Task attempt $taskAttemptId aborted.")
- }
-
- def commitJob(): Unit = {
- outputCommitter.commitJob(jobContext)
- logInfo(s"Job $jobId committed.")
- }
-
- def abortJob(): Unit = {
- if (outputCommitter != null) {
- outputCommitter.abortJob(jobContext, JobStatus.State.FAILED)
- }
- logError(s"Job $jobId aborted.")
- }
-}
-
-private[sql] class DefaultWriterContainer(
- @transient relation: HadoopFsRelation,
- @transient job: Job,
- isAppend: Boolean)
- extends BaseWriterContainer(relation, job, isAppend) {
-
- @transient private var writer: OutputWriter = _
-
- override protected def initWriters(): Unit = {
- taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath)
- writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext)
- }
-
- override def outputWriterForRow(row: InternalRow): OutputWriter = writer
-
- override def commitTask(): Unit = {
- try {
- assert(writer != null, "OutputWriter instance should have been initialized")
- writer.close()
- super.commitTask()
- } catch { case cause: Throwable =>
- // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and will
- // cause `abortTask()` to be invoked.
- throw new RuntimeException("Failed to commit task", cause)
- }
- }
-
- override def abortTask(): Unit = {
- try {
- // It's possible that the task fails before `writer` gets initialized
- if (writer != null) {
- writer.close()
- }
- } finally {
- super.abortTask()
- }
- }
-}
-
-private[sql] class DynamicPartitionWriterContainer(
- @transient relation: HadoopFsRelation,
- @transient job: Job,
- partitionColumns: Array[String],
- defaultPartitionName: String,
- isAppend: Boolean)
- extends BaseWriterContainer(relation, job, isAppend) {
-
- // All output writers are created on executor side.
- @transient protected var outputWriters: java.util.HashMap[String, OutputWriter] = _
-
- override protected def initWriters(): Unit = {
- outputWriters = new java.util.HashMap[String, OutputWriter]
- }
-
- // The `row` argument is supposed to only contain partition column values which have been casted
- // to strings.
- override def outputWriterForRow(row: InternalRow): OutputWriter = {
- val partitionPath = {
- val partitionPathBuilder = new StringBuilder
- var i = 0
-
- while (i < partitionColumns.length) {
- val col = partitionColumns(i)
- val partitionValueString = {
- val string = row.getUTF8String(i)
- if (string.eq(null)) {
- defaultPartitionName
- } else {
- PartitioningUtils.escapePathName(string.toString)
- }
- }
-
- if (i > 0) {
- partitionPathBuilder.append(Path.SEPARATOR_CHAR)
- }
-
- partitionPathBuilder.append(s"$col=$partitionValueString")
- i += 1
- }
-
- partitionPathBuilder.toString()
- }
-
- val writer = outputWriters.get(partitionPath)
- if (writer.eq(null)) {
- val path = new Path(getWorkPath, partitionPath)
- taskAttemptContext.getConfiguration.set(
- "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString)
- val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext)
- outputWriters.put(partitionPath, newWriter)
- newWriter
- } else {
- writer
- }
- }
-
- private def clearOutputWriters(): Unit = {
- if (!outputWriters.isEmpty) {
- asScalaIterator(outputWriters.values().iterator()).foreach(_.close())
- outputWriters.clear()
- }
- }
-
- override def commitTask(): Unit = {
- try {
- clearOutputWriters()
- super.commitTask()
- } catch { case cause: Throwable =>
- throw new RuntimeException("Failed to commit task", cause)
- }
- }
-
- override def abortTask(): Unit = {
- try {
- clearOutputWriters()
- } finally {
- super.abortTask()
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 0cdb407ad57b9..ecd304c30cdee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -17,340 +17,12 @@
package org.apache.spark.sql.execution.datasources
-import scala.language.{existentials, implicitConversions}
-import scala.util.matching.Regex
-
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark.Logging
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier}
import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode}
-import org.apache.spark.util.Utils
-
-/**
- * A parser for foreign DDL commands.
- */
-private[sql] class DDLParser(
- parseQuery: String => LogicalPlan)
- extends AbstractSparkSQLParser with DataTypeParser with Logging {
-
- def parse(input: String, exceptionOnError: Boolean): LogicalPlan = {
- try {
- parse(input)
- } catch {
- case ddlException: DDLException => throw ddlException
- case _ if !exceptionOnError => parseQuery(input)
- case x: Throwable => throw x
- }
- }
-
- // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
- // properties via reflection the class in runtime for constructing the SqlLexical object
- protected val CREATE = Keyword("CREATE")
- protected val TEMPORARY = Keyword("TEMPORARY")
- protected val TABLE = Keyword("TABLE")
- protected val IF = Keyword("IF")
- protected val NOT = Keyword("NOT")
- protected val EXISTS = Keyword("EXISTS")
- protected val USING = Keyword("USING")
- protected val OPTIONS = Keyword("OPTIONS")
- protected val DESCRIBE = Keyword("DESCRIBE")
- protected val EXTENDED = Keyword("EXTENDED")
- protected val AS = Keyword("AS")
- protected val COMMENT = Keyword("COMMENT")
- protected val REFRESH = Keyword("REFRESH")
-
- protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable
-
- protected def start: Parser[LogicalPlan] = ddl
-
- /**
- * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS]
- * USING org.apache.spark.sql.avro
- * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
- * or
- * `CREATE [TEMPORARY] TABLE avroTable(intField int, stringField string...) [IF NOT EXISTS]
- * USING org.apache.spark.sql.avro
- * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
- * or
- * `CREATE [TEMPORARY] TABLE avroTable [IF NOT EXISTS]
- * USING org.apache.spark.sql.avro
- * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
- * AS SELECT ...
- */
- protected lazy val createTable: Parser[LogicalPlan] =
- // TODO: Support database.table.
- (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~
- tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ {
- case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query =>
- if (temp.isDefined && allowExisting.isDefined) {
- throw new DDLException(
- "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.")
- }
-
- val options = opts.getOrElse(Map.empty[String, String])
- if (query.isDefined) {
- if (columns.isDefined) {
- throw new DDLException(
- "a CREATE TABLE AS SELECT statement does not allow column definitions.")
- }
- // When IF NOT EXISTS clause appears in the query, the save mode will be ignore.
- val mode = if (allowExisting.isDefined) {
- SaveMode.Ignore
- } else if (temp.isDefined) {
- SaveMode.Overwrite
- } else {
- SaveMode.ErrorIfExists
- }
-
- val queryPlan = parseQuery(query.get)
- CreateTableUsingAsSelect(tableName,
- provider,
- temp.isDefined,
- Array.empty[String],
- mode,
- options,
- queryPlan)
- } else {
- val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
- CreateTableUsing(
- tableName,
- userSpecifiedSchema,
- provider,
- temp.isDefined,
- options,
- allowExisting.isDefined,
- managedIfNoPath = false)
- }
- }
-
- protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
-
- /*
- * describe [extended] table avroTable
- * This will display all columns of table `avroTable` includes column_name,column_type,comment
- */
- protected lazy val describeTable: Parser[LogicalPlan] =
- (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ {
- case e ~ db ~ tbl =>
- val tblIdentifier = db match {
- case Some(dbName) =>
- Seq(dbName, tbl)
- case None =>
- Seq(tbl)
- }
- DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined)
- }
-
- protected lazy val refreshTable: Parser[LogicalPlan] =
- REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ {
- case maybeDatabaseName ~ tableName =>
- RefreshTable(TableIdentifier(tableName, maybeDatabaseName))
- }
-
- protected lazy val options: Parser[Map[String, String]] =
- "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
-
- protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")}
-
- override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch(
- s"identifier matching regex $regex", {
- case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str
- case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str
- }
- )
-
- protected lazy val optionPart: Parser[String] = "[_a-zA-Z][_a-zA-Z0-9]*".r ^^ {
- case name => name
- }
-
- protected lazy val optionName: Parser[String] = repsep(optionPart, ".") ^^ {
- case parts => parts.mkString(".")
- }
-
- protected lazy val pair: Parser[(String, String)] =
- optionName ~ stringLit ^^ { case k ~ v => (k, v) }
-
- protected lazy val column: Parser[StructField] =
- ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm =>
- val meta = cm match {
- case Some(comment) =>
- new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build()
- case None => Metadata.empty
- }
-
- StructField(columnName, typ, nullable = true, meta)
- }
-}
-
-private[sql] object ResolvedDataSource {
-
- private val builtinSources = Map(
- "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource",
- "json" -> "org.apache.spark.sql.json.DefaultSource",
- "parquet" -> "org.apache.spark.sql.parquet.DefaultSource",
- "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource"
- )
-
- /** Given a provider name, look up the data source class definition. */
- def lookupDataSource(provider: String): Class[_] = {
- val loader = Utils.getContextOrSparkClassLoader
-
- if (builtinSources.contains(provider)) {
- return loader.loadClass(builtinSources(provider))
- }
-
- try {
- loader.loadClass(provider)
- } catch {
- case cnf: java.lang.ClassNotFoundException =>
- try {
- loader.loadClass(provider + ".DefaultSource")
- } catch {
- case cnf: java.lang.ClassNotFoundException =>
- if (provider.startsWith("org.apache.spark.sql.hive.orc")) {
- sys.error("The ORC data source must be used with Hive support enabled.")
- } else {
- sys.error(s"Failed to load class for data source: $provider")
- }
- }
- }
- }
-
- /** Create a [[ResolvedDataSource]] for reading data in. */
- def apply(
- sqlContext: SQLContext,
- userSpecifiedSchema: Option[StructType],
- partitionColumns: Array[String],
- provider: String,
- options: Map[String, String]): ResolvedDataSource = {
- val clazz: Class[_] = lookupDataSource(provider)
- def className: String = clazz.getCanonicalName
- val relation = userSpecifiedSchema match {
- case Some(schema: StructType) => clazz.newInstance() match {
- case dataSource: SchemaRelationProvider =>
- dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
- case dataSource: HadoopFsRelationProvider =>
- val maybePartitionsSchema = if (partitionColumns.isEmpty) {
- None
- } else {
- Some(partitionColumnsSchema(schema, partitionColumns))
- }
-
- val caseInsensitiveOptions = new CaseInsensitiveMap(options)
- val paths = {
- val patternPath = new Path(caseInsensitiveOptions("path"))
- val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
- val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
- SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray
- }
-
- val dataSchema =
- StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable
-
- dataSource.createRelation(
- sqlContext,
- paths,
- Some(dataSchema),
- maybePartitionsSchema,
- caseInsensitiveOptions)
- case dataSource: org.apache.spark.sql.sources.RelationProvider =>
- throw new AnalysisException(s"$className does not allow user-specified schemas.")
- case _ =>
- throw new AnalysisException(s"$className is not a RelationProvider.")
- }
-
- case None => clazz.newInstance() match {
- case dataSource: RelationProvider =>
- dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
- case dataSource: HadoopFsRelationProvider =>
- val caseInsensitiveOptions = new CaseInsensitiveMap(options)
- val paths = {
- val patternPath = new Path(caseInsensitiveOptions("path"))
- val fs = patternPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
- val qualifiedPattern = patternPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
- SparkHadoopUtil.get.globPathIfNecessary(qualifiedPattern).map(_.toString).toArray
- }
- dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions)
- case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
- throw new AnalysisException(
- s"A schema needs to be specified when using $className.")
- case _ =>
- throw new AnalysisException(
- s"$className is neither a RelationProvider nor a FSBasedRelationProvider.")
- }
- }
- new ResolvedDataSource(clazz, relation)
- }
-
- private def partitionColumnsSchema(
- schema: StructType,
- partitionColumns: Array[String]): StructType = {
- StructType(partitionColumns.map { col =>
- schema.find(_.name == col).getOrElse {
- throw new RuntimeException(s"Partition column $col not found in schema $schema")
- }
- }).asNullable
- }
-
- /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */
- def apply(
- sqlContext: SQLContext,
- provider: String,
- partitionColumns: Array[String],
- mode: SaveMode,
- options: Map[String, String],
- data: DataFrame): ResolvedDataSource = {
- if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
- throw new AnalysisException("Cannot save interval data type into external storage.")
- }
- val clazz: Class[_] = lookupDataSource(provider)
- val relation = clazz.newInstance() match {
- case dataSource: CreatableRelationProvider =>
- dataSource.createRelation(sqlContext, mode, options, data)
- case dataSource: HadoopFsRelationProvider =>
- // Don't glob path for the write path. The contracts here are:
- // 1. Only one output path can be specified on the write path;
- // 2. Output path must be a legal HDFS style file system path;
- // 3. It's OK that the output path doesn't exist yet;
- val caseInsensitiveOptions = new CaseInsensitiveMap(options)
- val outputPath = {
- val path = new Path(caseInsensitiveOptions("path"))
- val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
- path.makeQualified(fs.getUri, fs.getWorkingDirectory)
- }
- val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name)))
- val r = dataSource.createRelation(
- sqlContext,
- Array(outputPath.toString),
- Some(dataSchema.asNullable),
- Some(partitionColumnsSchema(data.schema, partitionColumns)),
- caseInsensitiveOptions)
-
- // For partitioned relation r, r.schema's column ordering can be different from the column
- // ordering of data.logicalPlan (partition columns are all moved after data column). This
- // will be adjusted within InsertIntoHadoopFsRelation.
- sqlContext.executePlan(
- InsertIntoHadoopFsRelation(
- r,
- data.logicalPlan,
- mode)).toRdd
- r
- case _ =>
- sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
- }
- new ResolvedDataSource(clazz, relation)
- }
-}
-
-private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRelation)
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
/**
* Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command.
@@ -358,11 +30,12 @@ private[sql] case class ResolvedDataSource(provider: Class[_], relation: BaseRel
* @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false.
* It is effective only when the table is a Hive table.
*/
-private[sql] case class DescribeCommand(
+case class DescribeCommand(
table: LogicalPlan,
isExtended: Boolean) extends LogicalPlan with Command {
override def children: Seq[LogicalPlan] = Seq.empty
+
override val output: Seq[Attribute] = Seq(
// Column names are based on Hive.
AttributeReference("col_name", StringType, nullable = false,
@@ -370,7 +43,8 @@ private[sql] case class DescribeCommand(
AttributeReference("data_type", StringType, nullable = false,
new MetadataBuilder().putString("comment", "data type of the column").build())(),
AttributeReference("comment", StringType, nullable = false,
- new MetadataBuilder().putString("comment", "comment of the column").build())())
+ new MetadataBuilder().putString("comment", "comment of the column").build())()
+ )
}
/**
@@ -378,7 +52,7 @@ private[sql] case class DescribeCommand(
* @param allowExisting If it is true, we will do nothing when the table already exists.
* If it is false, an exception will be thrown
*/
-private[sql] case class CreateTableUsing(
+case class CreateTableUsing(
tableName: String,
userSpecifiedSchema: Option[StructType],
provider: String,
@@ -397,7 +71,7 @@ private[sql] case class CreateTableUsing(
* can analyze the logical plan that will be used to populate the table.
* So, [[PreWriteCheck]] can detect cases that are not allowed.
*/
-private[sql] case class CreateTableUsingAsSelect(
+case class CreateTableUsingAsSelect(
tableName: String,
provider: String,
temporary: Boolean,
@@ -410,7 +84,7 @@ private[sql] case class CreateTableUsingAsSelect(
// override lazy val resolved = databaseName != None && childrenResolved
}
-private[sql] case class CreateTempTableUsing(
+case class CreateTempTableUsing(
tableName: String,
userSpecifiedSchema: Option[StructType],
provider: String,
@@ -425,7 +99,7 @@ private[sql] case class CreateTempTableUsing(
}
}
-private[sql] case class CreateTempTableUsingAsSelect(
+case class CreateTempTableUsingAsSelect(
tableName: String,
provider: String,
partitionColumns: Array[String],
@@ -443,7 +117,7 @@ private[sql] case class CreateTempTableUsingAsSelect(
}
}
-private[sql] case class RefreshTable(tableIdent: TableIdentifier)
+case class RefreshTable(tableIdent: TableIdentifier)
extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[Row] = {
@@ -472,7 +146,7 @@ private[sql] case class RefreshTable(tableIdent: TableIdentifier)
/**
* Builds a map in which keys are case insensitive
*/
-protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
+class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
with Serializable {
val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))
@@ -490,4 +164,4 @@ protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[St
/**
* The exception thrown from the DDL parser.
*/
-protected[sql] class DDLException(message: String) extends Exception(message)
+class DDLException(message: String) extends RuntimeException(message)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
new file mode 100644
index 0000000000000..6773afc794f9c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import java.util.Properties
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister}
+
+class DefaultSource extends RelationProvider with DataSourceRegister {
+
+ override def shortName(): String = "jdbc"
+
+ /** Returns a new base relation with the given parameters. */
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
+ val driver = parameters.getOrElse("driver", null)
+ val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
+ val partitionColumn = parameters.getOrElse("partitionColumn", null)
+ val lowerBound = parameters.getOrElse("lowerBound", null)
+ val upperBound = parameters.getOrElse("upperBound", null)
+ val numPartitions = parameters.getOrElse("numPartitions", null)
+
+ if (driver != null) DriverRegistry.register(driver)
+
+ if (partitionColumn != null
+ && (lowerBound == null || upperBound == null || numPartitions == null)) {
+ sys.error("Partitioning incompletely specified")
+ }
+
+ val partitionInfo = if (partitionColumn == null) {
+ null
+ } else {
+ JDBCPartitioningInfo(
+ partitionColumn,
+ lowerBound.toLong,
+ upperBound.toLong,
+ numPartitions.toInt)
+ }
+ val parts = JDBCRelation.columnPartition(partitionInfo)
+ val properties = new Properties() // Additional properties that we will pass to getConnection
+ parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
+ JDBCRelation(url, table, parts, properties)(sqlContext)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
new file mode 100644
index 0000000000000..7ccd61ed469e9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import java.sql.{Driver, DriverManager}
+
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * java.sql.DriverManager is always loaded by bootstrap classloader,
+ * so it can't load JDBC drivers accessible by Spark ClassLoader.
+ *
+ * To solve the problem, drivers from user-supplied jars are wrapped into a thin wrapper.
+ */
+object DriverRegistry extends Logging {
+
+ private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty
+
+ def register(className: String): Unit = {
+ val cls = Utils.getContextOrSparkClassLoader.loadClass(className)
+ if (cls.getClassLoader == null) {
+ logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required")
+ } else if (wrapperMap.get(className).isDefined) {
+ logTrace(s"Wrapper for $className already exists")
+ } else {
+ synchronized {
+ if (wrapperMap.get(className).isEmpty) {
+ val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver])
+ DriverManager.registerDriver(wrapper)
+ wrapperMap(className) = wrapper
+ logTrace(s"Wrapper for $className registered")
+ }
+ }
+ }
+ }
+
+ def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
+ case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
+ case driver => driver.getClass.getCanonicalName
+ }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala
new file mode 100644
index 0000000000000..18263fe227d04
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverWrapper.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import java.sql.{Connection, Driver, DriverPropertyInfo, SQLFeatureNotSupportedException}
+import java.util.Properties
+
+/**
+ * A wrapper for a JDBC Driver to work around SPARK-6913.
+ *
+ * The problem is in `java.sql.DriverManager` class that can't access drivers loaded by
+ * Spark ClassLoader.
+ */
+class DriverWrapper(val wrapped: Driver) extends Driver {
+ override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url)
+
+ override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant()
+
+ override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = {
+ wrapped.getPropertyInfo(url, info)
+ }
+
+ override def getMinorVersion: Int = wrapped.getMinorVersion
+
+ def getParentLogger: java.util.logging.Logger = {
+ throw new SQLFeatureNotSupportedException(
+ s"${this.getClass.getName}.getParentLogger is not yet implemented.")
+ }
+
+ override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info)
+
+ override def getMajorVersion: Int = wrapped.getMajorVersion
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
similarity index 98%
rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 3cf70db6b7b09..8eab6a0adccc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.jdbc
+package org.apache.spark.sql.execution.datasources.jdbc
import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException}
import java.util.Properties
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -180,9 +181,8 @@ private[sql] object JDBCRDD extends Logging {
try {
if (driver != null) DriverRegistry.register(driver)
} catch {
- case e: ClassNotFoundException => {
- logWarning(s"Couldn't find class $driver", e);
- }
+ case e: ClassNotFoundException =>
+ logWarning(s"Couldn't find class $driver", e)
}
DriverManager.getConnection(url, properties)
}
@@ -344,7 +344,6 @@ private[sql] class JDBCRDD(
}).toArray
}
-
/**
* Runs the SQL query against the JDBC driver.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
similarity index 72%
rename from sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 41d0ecb4bbfbf..f9300dc2cb529 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.jdbc
+package org.apache.spark.sql.execution.datasources.jdbc
import java.util.Properties
@@ -77,42 +77,6 @@ private[sql] object JDBCRelation {
}
}
-private[sql] class DefaultSource extends RelationProvider {
- /** Returns a new base relation with the given parameters. */
- override def createRelation(
- sqlContext: SQLContext,
- parameters: Map[String, String]): BaseRelation = {
- val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
- val driver = parameters.getOrElse("driver", null)
- val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
- val partitionColumn = parameters.getOrElse("partitionColumn", null)
- val lowerBound = parameters.getOrElse("lowerBound", null)
- val upperBound = parameters.getOrElse("upperBound", null)
- val numPartitions = parameters.getOrElse("numPartitions", null)
-
- if (driver != null) DriverRegistry.register(driver)
-
- if (partitionColumn != null
- && (lowerBound == null || upperBound == null || numPartitions == null)) {
- sys.error("Partitioning incompletely specified")
- }
-
- val partitionInfo = if (partitionColumn == null) {
- null
- } else {
- JDBCPartitioningInfo(
- partitionColumn,
- lowerBound.toLong,
- upperBound.toLong,
- numPartitions.toInt)
- }
- val parts = JDBCRelation.columnPartition(partitionInfo)
- val properties = new Properties() // Additional properties that we will pass to getConnection
- parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
- JDBCRelation(url, table, parts, properties)(sqlContext)
- }
-}
-
private[sql] case class JDBCRelation(
url: String,
table: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
new file mode 100644
index 0000000000000..039c13bf163ca
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import java.sql.{Connection, DriverManager, PreparedStatement}
+import java.util.Properties
+
+import scala.util.Try
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Row}
+
+/**
+ * Util functions for JDBC tables.
+ */
+object JdbcUtils extends Logging {
+
+ /**
+ * Establishes a JDBC connection.
+ */
+ def createConnection(url: String, connectionProperties: Properties): Connection = {
+ DriverManager.getConnection(url, connectionProperties)
+ }
+
+ /**
+ * Returns true if the table already exists in the JDBC database.
+ */
+ def tableExists(conn: Connection, table: String): Boolean = {
+ // Somewhat hacky, but there isn't a good way to identify whether a table exists for all
+ // SQL database systems, considering "table" could also include the database name.
+ Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess
+ }
+
+ /**
+ * Drops a table from the JDBC database.
+ */
+ def dropTable(conn: Connection, table: String): Unit = {
+ conn.prepareStatement(s"DROP TABLE $table").executeUpdate()
+ }
+
+ /**
+ * Returns a PreparedStatement that inserts a row into table via conn.
+ */
+ def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
+ val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
+ var fieldsLeft = rddSchema.fields.length
+ while (fieldsLeft > 0) {
+ sql.append("?")
+ if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
+ fieldsLeft = fieldsLeft - 1
+ }
+ conn.prepareStatement(sql.toString())
+ }
+
+ /**
+ * Saves a partition of a DataFrame to the JDBC database. This is done in
+   * a single database transaction, in order to avoid repeatedly inserting
+   * the same data as much as possible.
+ *
+ * It is still theoretically possible for rows in a DataFrame to be
+ * inserted into the database more than once if a stage somehow fails after
+ * the commit occurs but before the stage can return successfully.
+ *
+ * This is not a closure inside saveTable() because apparently cosmetic
+ * implementation changes elsewhere might easily render such a closure
+ * non-Serializable. Instead, we explicitly close over all variables that
+ * are used.
+ */
+ def savePartition(
+ getConnection: () => Connection,
+ table: String,
+ iterator: Iterator[Row],
+ rddSchema: StructType,
+ nullTypes: Array[Int]): Iterator[Byte] = {
+ val conn = getConnection()
+ var committed = false
+ try {
+ conn.setAutoCommit(false) // Everything in the same db transaction.
+ val stmt = insertStatement(conn, table, rddSchema)
+ try {
+ while (iterator.hasNext) {
+ val row = iterator.next()
+ val numFields = rddSchema.fields.length
+ var i = 0
+ while (i < numFields) {
+ if (row.isNullAt(i)) {
+ stmt.setNull(i + 1, nullTypes(i))
+ } else {
+ rddSchema.fields(i).dataType match {
+ case IntegerType => stmt.setInt(i + 1, row.getInt(i))
+ case LongType => stmt.setLong(i + 1, row.getLong(i))
+ case DoubleType => stmt.setDouble(i + 1, row.getDouble(i))
+ case FloatType => stmt.setFloat(i + 1, row.getFloat(i))
+ case ShortType => stmt.setInt(i + 1, row.getShort(i))
+ case ByteType => stmt.setInt(i + 1, row.getByte(i))
+ case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i))
+ case StringType => stmt.setString(i + 1, row.getString(i))
+ case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
+ case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
+ case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
+ case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
+ case _ => throw new IllegalArgumentException(
+ s"Can't translate non-null value for field $i")
+ }
+ }
+ i = i + 1
+ }
+ stmt.executeUpdate()
+ }
+ } finally {
+ stmt.close()
+ }
+ conn.commit()
+ committed = true
+ } finally {
+ if (!committed) {
+ // The stage must fail. We got here through an exception path, so
+ // let the exception through unless rollback() or close() want to
+ // tell the user about another problem.
+ conn.rollback()
+ conn.close()
+ } else {
+ // The stage must succeed. We cannot propagate any exception close() might throw.
+ try {
+ conn.close()
+ } catch {
+ case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
+ }
+ }
+ }
+ Array[Byte]().iterator
+ }
+
+ /**
+ * Compute the schema string for this RDD.
+ */
+ def schemaString(df: DataFrame, url: String): String = {
+ val sb = new StringBuilder()
+ val dialect = JdbcDialects.get(url)
+ df.schema.fields foreach { field => {
+ val name = field.name
+ val typ: String =
+ dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
+ field.dataType match {
+ case IntegerType => "INTEGER"
+ case LongType => "BIGINT"
+ case DoubleType => "DOUBLE PRECISION"
+ case FloatType => "REAL"
+ case ShortType => "INTEGER"
+ case ByteType => "BYTE"
+ case BooleanType => "BIT(1)"
+ case StringType => "TEXT"
+ case BinaryType => "BLOB"
+ case TimestampType => "TIMESTAMP"
+ case DateType => "DATE"
+          case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
+ case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
+ })
+ val nullable = if (field.nullable) "" else "NOT NULL"
+ sb.append(s", $name $typ $nullable")
+ }}
+ if (sb.length < 2) "" else sb.substring(2)
+ }
+
+ /**
+ * Saves the RDD to the database in a single transaction.
+ */
+ def saveTable(
+ df: DataFrame,
+ url: String,
+ table: String,
+ properties: Properties = new Properties()) {
+ val dialect = JdbcDialects.get(url)
+ val nullTypes: Array[Int] = df.schema.fields.map { field =>
+ dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
+ field.dataType match {
+ case IntegerType => java.sql.Types.INTEGER
+ case LongType => java.sql.Types.BIGINT
+ case DoubleType => java.sql.Types.DOUBLE
+ case FloatType => java.sql.Types.REAL
+ case ShortType => java.sql.Types.INTEGER
+ case ByteType => java.sql.Types.INTEGER
+ case BooleanType => java.sql.Types.BIT
+ case StringType => java.sql.Types.CLOB
+ case BinaryType => java.sql.Types.BLOB
+ case TimestampType => java.sql.Types.TIMESTAMP
+ case DateType => java.sql.Types.DATE
+ case t: DecimalType => java.sql.Types.DECIMAL
+ case _ => throw new IllegalArgumentException(
+ s"Can't translate null value for field $field")
+ })
+ }
+
+ val rddSchema = df.schema
+ val driver: String = DriverRegistry.getDriverClassName(url)
+ val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
+ df.foreachPartition { iterator =>
+ savePartition(getConnection, table, iterator, rddSchema, nullTypes)
+ }
+ }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
similarity index 89%
rename from sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
index 04ab5e2217882..b6f3410bad690 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.sql.json
+package org.apache.spark.sql.execution.datasources.json
import com.fasterxml.jackson.core._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
-import org.apache.spark.sql.json.JacksonUtils.nextUntil
+import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
private[sql] object InferSchema {
@@ -113,8 +113,12 @@ private[sql] object InferSchema {
case INT | LONG => LongType
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
- case BIG_INTEGER | BIG_DECIMAL => DecimalType.SYSTEM_DEFAULT
- case FLOAT | DOUBLE => DoubleType
+ case BIG_INTEGER | BIG_DECIMAL =>
+ val v = parser.getDecimalValue
+ DecimalType(v.precision(), v.scale())
+ case FLOAT | DOUBLE =>
+ // TODO(davies): Should we use decimal if possible?
+ DoubleType
}
case VALUE_TRUE | VALUE_FALSE => BooleanType
@@ -171,9 +175,18 @@ private[sql] object InferSchema {
// Double support larger range than fixed decimal, DecimalType.Maximum should be enough
// in most case, also have better precision.
case (DoubleType, t: DecimalType) =>
- if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
+ DoubleType
case (t: DecimalType, DoubleType) =>
- if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType
+ DoubleType
+ case (t1: DecimalType, t2: DecimalType) =>
+ val scale = math.max(t1.scale, t2.scale)
+ val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+ if (range + scale > 38) {
+ // DecimalType can't support precision > 38
+ DoubleType
+ } else {
+ DecimalType(range + scale, scale)
+ }
case (StructType(fields1), StructType(fields2)) =>
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
new file mode 100644
index 0000000000000..114c8b211891e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.json
+
+import java.io.CharArrayWriter
+
+import com.fasterxml.jackson.core.JsonFactory
+import com.google.common.base.Objects
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
+import org.apache.hadoop.mapred.{JobConf, TextInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
+import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
+
+import org.apache.spark.Logging
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.PartitionSpec
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.util.SerializableConfiguration
+
+
+class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+
+ override def shortName(): String = "json"
+
+ override def createRelation(
+ sqlContext: SQLContext,
+ paths: Array[String],
+ dataSchema: Option[StructType],
+ partitionColumns: Option[StructType],
+ parameters: Map[String, String]): HadoopFsRelation = {
+ val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
+ new JSONRelation(None, samplingRatio, dataSchema, None, partitionColumns, paths)(sqlContext)
+ }
+}
+
+private[sql] class JSONRelation(
+ val inputRDD: Option[RDD[String]],
+ val samplingRatio: Double,
+ val maybeDataSchema: Option[StructType],
+ val maybePartitionSpec: Option[PartitionSpec],
+ override val userDefinedPartitionColumns: Option[StructType],
+ override val paths: Array[String] = Array.empty[String])(@transient val sqlContext: SQLContext)
+ extends HadoopFsRelation(maybePartitionSpec) {
+
+ /** Constraints to be imposed on schema to be stored. */
+ private def checkConstraints(schema: StructType): Unit = {
+ if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
+ val duplicateColumns = schema.fieldNames.groupBy(identity).collect {
+ case (x, ys) if ys.length > 1 => "\"" + x + "\""
+ }.mkString(", ")
+ throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " +
+ s"cannot save to JSON format")
+ }
+ }
+
+ override val needConversion: Boolean = false
+
+ private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = {
+ val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
+ val conf = job.getConfiguration
+
+ val paths = inputPaths.map(_.getPath)
+
+ if (paths.nonEmpty) {
+ FileInputFormat.setInputPaths(job, paths: _*)
+ }
+
+ sqlContext.sparkContext.hadoopRDD(
+ conf.asInstanceOf[JobConf],
+ classOf[TextInputFormat],
+ classOf[LongWritable],
+ classOf[Text]).map(_._2.toString) // get the text line
+ }
+
+ override lazy val dataSchema = {
+ val jsonSchema = maybeDataSchema.getOrElse {
+ val files = cachedLeafStatuses().filterNot { status =>
+ val name = status.getPath.getName
+ name.startsWith("_") || name.startsWith(".")
+ }.toArray
+ InferSchema(
+ inputRDD.getOrElse(createBaseRdd(files)),
+ samplingRatio,
+ sqlContext.conf.columnNameOfCorruptRecord)
+ }
+ checkConstraints(jsonSchema)
+
+ jsonSchema
+ }
+
+ override private[sql] def buildScan(
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ inputPaths: Array[String],
+ broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = {
+ refresh()
+ super.buildScan(requiredColumns, filters, inputPaths, broadcastedConf)
+ }
+
+ override def buildScan(
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ inputPaths: Array[FileStatus]): RDD[Row] = {
+ JacksonParser(
+ inputRDD.getOrElse(createBaseRdd(inputPaths)),
+ StructType(requiredColumns.map(dataSchema(_))),
+ sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]]
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case that: JSONRelation =>
+ ((inputRDD, that.inputRDD) match {
+ case (Some(thizRdd), Some(thatRdd)) => thizRdd eq thatRdd
+ case (None, None) => true
+ case _ => false
+ }) && paths.toSet == that.paths.toSet &&
+ dataSchema == that.dataSchema &&
+ schema == that.schema
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ Objects.hashCode(
+ inputRDD,
+ paths.toSet,
+ dataSchema,
+ schema,
+ partitionColumns)
+ }
+
+ override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+ new OutputWriterFactory {
+ override def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ new JsonOutputWriter(path, dataSchema, context)
+ }
+ }
+ }
+}
+
+private[json] class JsonOutputWriter(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext)
+ extends OutputWriter with SparkHadoopMapRedUtil with Logging {
+
+ val writer = new CharArrayWriter()
+ // create the Generator without separator inserted between 2 records
+ val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+
+ val result = new Text()
+
+ private val recordWriter: RecordWriter[NullWritable, Text] = {
+ new TextOutputFormat[NullWritable, Text]() {
+ override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+ val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID")
+ val split = context.getTaskAttemptID.getTaskID.getId
+ new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
+ }
+ }.getRecordWriter(context)
+ }
+
+ override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
+
+ override protected[sql] def writeInternal(row: InternalRow): Unit = {
+ JacksonGenerator(dataSchema, gen, row)
+ gen.flush()
+
+ result.set(writer.toString)
+ writer.reset()
+
+ recordWriter.write(NullWritable.get(), result)
+ }
+
+ override def close(): Unit = {
+ gen.close()
+ recordWriter.close(context)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala
similarity index 56%
rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala
index 1e6b1198d245b..99ac7730bd1c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala
@@ -15,7 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.sql.json
+package org.apache.spark.sql.execution.datasources.json
+
+import org.apache.spark.sql.catalyst.InternalRow
import scala.collection.Map
@@ -74,4 +76,60 @@ private[sql] object JacksonGenerator {
valWriter(rowSchema, row)
}
+
+ /** Transforms a single InternalRow to JSON using Jackson
+ *
+ * TODO: make the code shared with the other apply method.
+ *
+ * @param rowSchema the schema object used for conversion
+ * @param gen a JsonGenerator object
+ * @param row The row to convert
+ */
+ def apply(rowSchema: StructType, gen: JsonGenerator, row: InternalRow): Unit = {
+ def valWriter: (DataType, Any) => Unit = {
+ case (_, null) | (NullType, _) => gen.writeNull()
+ case (StringType, v) => gen.writeString(v.toString)
+ case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
+ case (IntegerType, v: Int) => gen.writeNumber(v)
+ case (ShortType, v: Short) => gen.writeNumber(v)
+ case (FloatType, v: Float) => gen.writeNumber(v)
+ case (DoubleType, v: Double) => gen.writeNumber(v)
+ case (LongType, v: Long) => gen.writeNumber(v)
+ case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
+ case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
+ case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
+ case (BooleanType, v: Boolean) => gen.writeBoolean(v)
+ case (DateType, v) => gen.writeString(v.toString)
+ case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v))
+
+ case (ArrayType(ty, _), v: ArrayData) =>
+ gen.writeStartArray()
+ v.foreach(ty, (_, value) => valWriter(ty, value))
+ gen.writeEndArray()
+
+ case (MapType(kt, vt, _), v: MapData) =>
+ gen.writeStartObject()
+ v.foreach(kt, vt, { (k, v) =>
+ gen.writeFieldName(k.toString)
+ valWriter(vt, v)
+ })
+ gen.writeEndObject()
+
+ case (StructType(ty), v: InternalRow) =>
+ gen.writeStartObject()
+ var i = 0
+ while (i < ty.length) {
+ val field = ty(i)
+ val value = v.get(i, field.dataType)
+ if (value != null) {
+ gen.writeFieldName(field.name)
+ valWriter(field.dataType, value)
+ }
+ i += 1
+ }
+ gen.writeEndObject()
+ }
+
+ valWriter(rowSchema, row)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
similarity index 89%
rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
index 1c309f8794ef3..cd68bd667c5c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
@@ -15,23 +15,22 @@
* limitations under the License.
*/
-package org.apache.spark.sql.json
+package org.apache.spark.sql.execution.datasources.json
import java.io.ByteArrayOutputStream
-import scala.collection.Map
-
import com.fasterxml.jackson.core._
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.json.JacksonUtils.nextUntil
+import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-
private[sql] object JacksonParser {
def apply(
json: RDD[String],
@@ -85,9 +84,8 @@ private[sql] object JacksonParser {
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) =>
parser.getDoubleValue
- case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DecimalType()) =>
- // TODO: add fixed precision and scale handling
- Decimal(parser.getDecimalValue)
+ case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, dt: DecimalType) =>
+ Decimal(parser.getDecimalValue, dt.precision, dt.scale)
case (VALUE_NUMBER_INT, ByteType) =>
parser.getByteValue
@@ -127,7 +125,7 @@ private[sql] object JacksonParser {
convertMap(factory, parser, kt)
case (_, udt: UserDefinedType[_]) =>
- udt.deserialize(convertField(factory, parser, udt.sqlType))
+ convertField(factory, parser, udt.sqlType)
}
}
@@ -160,21 +158,21 @@ private[sql] object JacksonParser {
private def convertMap(
factory: JsonFactory,
parser: JsonParser,
- valueType: DataType): Map[UTF8String, Any] = {
- val builder = Map.newBuilder[UTF8String, Any]
+ valueType: DataType): MapData = {
+ val keys = ArrayBuffer.empty[UTF8String]
+ val values = ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_OBJECT)) {
- builder +=
- UTF8String.fromString(parser.getCurrentName) -> convertField(factory, parser, valueType)
+ keys += UTF8String.fromString(parser.getCurrentName)
+ values += convertField(factory, parser, valueType)
}
-
- builder.result()
+ ArrayBasedMapData(keys.toArray, values.toArray)
}
private def convertArray(
factory: JsonFactory,
parser: JsonParser,
elementType: DataType): ArrayData = {
- val values = scala.collection.mutable.ArrayBuffer.empty[Any]
+ val values = ArrayBuffer.empty[Any]
while (nextUntil(parser, JsonToken.END_ARRAY)) {
values += convertField(factory, parser, elementType)
}
@@ -213,7 +211,7 @@ private[sql] object JacksonParser {
if (array.numElements() == 0) {
Nil
} else {
- array.toArray().map(_.asInstanceOf[InternalRow])
+ array.toArray[InternalRow](schema)
}
case _ =>
sys.error(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala
similarity index 95%
rename from sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala
index fde96852ce68e..005546f37dda0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonUtils.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.json
+package org.apache.spark.sql.execution.datasources.json
import com.fasterxml.jackson.core.{JsonParser, JsonToken}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala
similarity index 99%
rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala
index 975fec101d9c2..4049795ed3bad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.parquet
+package org.apache.spark.sql.execution.datasources.parquet
import java.util.{Map => JMap}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala
similarity index 96%
rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala
index 84f1dccfeb788..ed9e0aa65977b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.parquet
+package org.apache.spark.sql.execution.datasources.parquet
import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer}
import org.apache.parquet.schema.MessageType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
similarity index 71%
rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
index 172db8362afb6..ab5a6ddd41cfc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
@@ -15,17 +15,17 @@
* limitations under the License.
*/
-package org.apache.spark.sql.parquet
+package org.apache.spark.sql.execution.datasources.parquet
import java.math.{BigDecimal, BigInteger}
import java.nio.ByteOrder
import scala.collection.JavaConversions._
-import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.parquet.column.Dictionary
import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
+import org.apache.parquet.schema.OriginalType.LIST
import org.apache.parquet.schema.Type.Repetition
import org.apache.parquet.schema.{GroupType, PrimitiveType, Type}
@@ -42,6 +42,12 @@ import org.apache.spark.unsafe.types.UTF8String
* values to an [[ArrayBuffer]].
*/
private[parquet] trait ParentContainerUpdater {
+ /** Called before a record field is converted */
+ def start(): Unit = ()
+
+ /** Called after a record field has been converted */
+ def end(): Unit = ()
+
def set(value: Any): Unit = ()
def setBoolean(value: Boolean): Unit = set(value)
def setByte(value: Byte): Unit = set(value)
@@ -55,6 +61,32 @@ private[parquet] trait ParentContainerUpdater {
/** A no-op updater used for root converter (who doesn't have a parent). */
private[parquet] object NoopUpdater extends ParentContainerUpdater
+private[parquet] trait HasParentContainerUpdater {
+ def updater: ParentContainerUpdater
+}
+
+/**
+ * A convenient converter class for Parquet group types with a [[HasParentContainerUpdater]].
+ */
+private[parquet] abstract class CatalystGroupConverter(val updater: ParentContainerUpdater)
+ extends GroupConverter with HasParentContainerUpdater
+
+/**
+ * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types
+ * are handled by this converter. Parquet primitive types are only a subset of those of Spark
+ * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet.
+ */
+private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUpdater)
+ extends PrimitiveConverter with HasParentContainerUpdater {
+
+ override def addBoolean(value: Boolean): Unit = updater.setBoolean(value)
+ override def addInt(value: Int): Unit = updater.setInt(value)
+ override def addLong(value: Long): Unit = updater.setLong(value)
+ override def addFloat(value: Float): Unit = updater.setFloat(value)
+ override def addDouble(value: Double): Unit = updater.setDouble(value)
+ override def addBinary(value: Binary): Unit = updater.set(value.getBytes)
+}
+
/**
* A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s.
* Since any Parquet record is also a struct, this converter can also be used as root converter.
@@ -70,7 +102,7 @@ private[parquet] class CatalystRowConverter(
parquetType: GroupType,
catalystType: StructType,
updater: ParentContainerUpdater)
- extends GroupConverter {
+ extends CatalystGroupConverter(updater) {
/**
* Updater used together with field converters within a [[CatalystRowConverter]]. It propagates
@@ -89,13 +121,11 @@ private[parquet] class CatalystRowConverter(
/**
* Represents the converted row object once an entire Parquet record is converted.
- *
- * @todo Uses [[UnsafeRow]] for better performance.
*/
val currentRow = new SpecificMutableRow(catalystType.map(_.dataType))
// Converters for each field.
- private val fieldConverters: Array[Converter] = {
+ private val fieldConverters: Array[Converter with HasParentContainerUpdater] = {
parquetType.getFields.zip(catalystType).zipWithIndex.map {
case ((parquetFieldType, catalystField), ordinal) =>
// Converted field value should be set to the `ordinal`-th cell of `currentRow`
@@ -105,11 +135,19 @@ private[parquet] class CatalystRowConverter(
override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex)
- override def end(): Unit = updater.set(currentRow)
+ override def end(): Unit = {
+ var i = 0
+ while (i < currentRow.numFields) {
+ fieldConverters(i).updater.end()
+ i += 1
+ }
+ updater.set(currentRow)
+ }
override def start(): Unit = {
var i = 0
while (i < currentRow.numFields) {
+ fieldConverters(i).updater.start()
currentRow.setNullAt(i)
i += 1
}
@@ -122,20 +160,20 @@ private[parquet] class CatalystRowConverter(
private def newConverter(
parquetType: Type,
catalystType: DataType,
- updater: ParentContainerUpdater): Converter = {
+ updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = {
catalystType match {
case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType =>
new CatalystPrimitiveConverter(updater)
case ByteType =>
- new PrimitiveConverter {
+ new CatalystPrimitiveConverter(updater) {
override def addInt(value: Int): Unit =
updater.setByte(value.asInstanceOf[ByteType#InternalType])
}
case ShortType =>
- new PrimitiveConverter {
+ new CatalystPrimitiveConverter(updater) {
override def addInt(value: Int): Unit =
updater.setShort(value.asInstanceOf[ShortType#InternalType])
}
@@ -148,7 +186,7 @@ private[parquet] class CatalystRowConverter(
case TimestampType =>
// TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that.
- new PrimitiveConverter {
+ new CatalystPrimitiveConverter(updater) {
// Converts nanosecond timestamps stored as INT96
override def addBinary(value: Binary): Unit = {
assert(
@@ -164,13 +202,23 @@ private[parquet] class CatalystRowConverter(
}
case DateType =>
- new PrimitiveConverter {
+ new CatalystPrimitiveConverter(updater) {
override def addInt(value: Int): Unit = {
// DateType is not specialized in `SpecificMutableRow`, have to box it here.
updater.set(value.asInstanceOf[DateType#InternalType])
}
}
+ // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor
+ // annotated by `LIST` or `MAP` should be interpreted as a required list of required
+ // elements where the element type is the type of the field.
+ case t: ArrayType if parquetType.getOriginalType != LIST =>
+ if (parquetType.isPrimitive) {
+ new RepeatedPrimitiveConverter(parquetType, t.elementType, updater)
+ } else {
+ new RepeatedGroupConverter(parquetType, t.elementType, updater)
+ }
+
case t: ArrayType =>
new CatalystArrayConverter(parquetType.asGroupType(), t, updater)
@@ -195,27 +243,11 @@ private[parquet] class CatalystRowConverter(
}
}
- /**
- * Parquet converter for Parquet primitive types. Note that not all Spark SQL atomic types
- * are handled by this converter. Parquet primitive types are only a subset of those of Spark
- * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet.
- */
- private final class CatalystPrimitiveConverter(updater: ParentContainerUpdater)
- extends PrimitiveConverter {
-
- override def addBoolean(value: Boolean): Unit = updater.setBoolean(value)
- override def addInt(value: Int): Unit = updater.setInt(value)
- override def addLong(value: Long): Unit = updater.setLong(value)
- override def addFloat(value: Float): Unit = updater.setFloat(value)
- override def addDouble(value: Double): Unit = updater.setDouble(value)
- override def addBinary(value: Binary): Unit = updater.set(value.getBytes)
- }
-
/**
* Parquet converter for strings. A dictionary is used to minimize string decoding cost.
*/
private final class CatalystStringConverter(updater: ParentContainerUpdater)
- extends PrimitiveConverter {
+ extends CatalystPrimitiveConverter(updater) {
private var expandedDictionary: Array[UTF8String] = null
@@ -242,7 +274,7 @@ private[parquet] class CatalystRowConverter(
private final class CatalystDecimalConverter(
decimalType: DecimalType,
updater: ParentContainerUpdater)
- extends PrimitiveConverter {
+ extends CatalystPrimitiveConverter(updater) {
// Converts decimals stored as INT32
override def addInt(value: Int): Unit = {
@@ -264,7 +296,7 @@ private[parquet] class CatalystRowConverter(
val scale = decimalType.scale
val bytes = value.getBytes
- if (precision <= 8) {
+ if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) {
// Constructs a `Decimal` with an unscaled `Long` value if possible.
var unscaled = 0L
var i = 0
@@ -306,7 +338,7 @@ private[parquet] class CatalystRowConverter(
parquetSchema: GroupType,
catalystSchema: ArrayType,
updater: ParentContainerUpdater)
- extends GroupConverter {
+ extends CatalystGroupConverter(updater) {
private var currentArray: ArrayBuffer[Any] = _
@@ -383,9 +415,10 @@ private[parquet] class CatalystRowConverter(
parquetType: GroupType,
catalystType: MapType,
updater: ParentContainerUpdater)
- extends GroupConverter {
+ extends CatalystGroupConverter(updater) {
- private var currentMap: mutable.Map[Any, Any] = _
+ private var currentKeys: ArrayBuffer[Any] = _
+ private var currentValues: ArrayBuffer[Any] = _
private val keyValueConverter = {
val repeatedType = parquetType.getType(0).asGroupType()
@@ -398,12 +431,16 @@ private[parquet] class CatalystRowConverter(
override def getConverter(fieldIndex: Int): Converter = keyValueConverter
- override def end(): Unit = updater.set(currentMap)
+ override def end(): Unit =
+ updater.set(ArrayBasedMapData(currentKeys.toArray, currentValues.toArray))
// NOTE: We can't reuse the mutable Map here and must instantiate a new `Map` for the next
// value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored in row
// cells.
- override def start(): Unit = currentMap = mutable.Map.empty[Any, Any]
+ override def start(): Unit = {
+ currentKeys = ArrayBuffer.empty[Any]
+ currentValues = ArrayBuffer.empty[Any]
+ }
/** Parquet converter for key-value pairs within the map. */
private final class KeyValueConverter(
@@ -430,7 +467,10 @@ private[parquet] class CatalystRowConverter(
override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
- override def end(): Unit = currentMap(currentKey) = currentValue
+ override def end(): Unit = {
+ currentKeys += currentKey
+ currentValues += currentValue
+ }
override def start(): Unit = {
currentKey = null
@@ -438,4 +478,61 @@ private[parquet] class CatalystRowConverter(
}
}
}
+
+ private trait RepeatedConverter {
+ private var currentArray: ArrayBuffer[Any] = _
+
+ protected def newArrayUpdater(updater: ParentContainerUpdater) = new ParentContainerUpdater {
+ override def start(): Unit = currentArray = ArrayBuffer.empty[Any]
+ override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray))
+ override def set(value: Any): Unit = currentArray += value
+ }
+ }
+
+ /**
+ * A primitive converter for converting unannotated repeated primitive values to required arrays
+ * of required primitive values.
+ */
+ private final class RepeatedPrimitiveConverter(
+ parquetType: Type,
+ catalystType: DataType,
+ parentUpdater: ParentContainerUpdater)
+ extends PrimitiveConverter with RepeatedConverter with HasParentContainerUpdater {
+
+ val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater)
+
+ private val elementConverter: PrimitiveConverter =
+ newConverter(parquetType, catalystType, updater).asPrimitiveConverter()
+
+ override def addBoolean(value: Boolean): Unit = elementConverter.addBoolean(value)
+ override def addInt(value: Int): Unit = elementConverter.addInt(value)
+ override def addLong(value: Long): Unit = elementConverter.addLong(value)
+ override def addFloat(value: Float): Unit = elementConverter.addFloat(value)
+ override def addDouble(value: Double): Unit = elementConverter.addDouble(value)
+ override def addBinary(value: Binary): Unit = elementConverter.addBinary(value)
+
+ override def setDictionary(dict: Dictionary): Unit = elementConverter.setDictionary(dict)
+ override def hasDictionarySupport: Boolean = elementConverter.hasDictionarySupport
+ override def addValueFromDictionary(id: Int): Unit = elementConverter.addValueFromDictionary(id)
+ }
+
+ /**
+ * A group converter for converting unannotated repeated group values to required arrays of
+ * required struct values.
+ */
+ private final class RepeatedGroupConverter(
+ parquetType: Type,
+ catalystType: DataType,
+ parentUpdater: ParentContainerUpdater)
+ extends GroupConverter with HasParentContainerUpdater with RepeatedConverter {
+
+ val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater)
+
+ private val elementConverter: GroupConverter =
+ newConverter(parquetType, catalystType, updater).asGroupConverter()
+
+ override def getConverter(field: Int): Converter = elementConverter.getConverter(field)
+ override def end(): Unit = elementConverter.end()
+ override def start(): Unit = elementConverter.start()
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
similarity index 95%
rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
index d43ca95b4eea0..275646e8181ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.parquet
+package org.apache.spark.sql.execution.datasources.parquet
import scala.collection.JavaConversions._
@@ -25,6 +25,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.parquet.schema.Type.Repetition._
import org.apache.parquet.schema._
+import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{AnalysisException, SQLConf}
@@ -99,8 +100,11 @@ private[parquet] class CatalystSchemaConverter(
StructField(field.getName, convertField(field), nullable = false)
case REPEATED =>
- throw new AnalysisException(
- s"REPEATED not supported outside LIST or MAP. Type: $field")
+ // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor
+ // annotated by `LIST` or `MAP` should be interpreted as a required list of required
+ // elements where the element type is the type of the field.
+ val arrayType = ArrayType(convertField(field), containsNull = false)
+ StructField(field.getName, arrayType, nullable = false)
}
}
@@ -155,7 +159,7 @@ private[parquet] class CatalystSchemaConverter(
case INT_16 => ShortType
case INT_32 | null => IntegerType
case DATE => DateType
- case DECIMAL => makeDecimalType(maxPrecisionForBytes(4))
+ case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32)
case TIME_MILLIS => typeNotImplemented()
case _ => illegalType()
}
@@ -163,7 +167,7 @@ private[parquet] class CatalystSchemaConverter(
case INT64 =>
originalType match {
case INT_64 | null => LongType
- case DECIMAL => makeDecimalType(maxPrecisionForBytes(8))
+ case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64)
case TIMESTAMP_MILLIS => typeNotImplemented()
case _ => illegalType()
}
@@ -405,7 +409,7 @@ private[parquet] class CatalystSchemaConverter(
// Uses INT32 for 1 <= precision <= 9
case DecimalType.Fixed(precision, scale)
- if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec =>
+ if precision <= MAX_PRECISION_FOR_INT32 && followParquetFormatSpec =>
Types
.primitive(INT32, repetition)
.as(DECIMAL)
@@ -415,7 +419,7 @@ private[parquet] class CatalystSchemaConverter(
// Uses INT64 for 1 <= precision <= 18
case DecimalType.Fixed(precision, scale)
- if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec =>
+ if precision <= MAX_PRECISION_FOR_INT64 && followParquetFormatSpec =>
Types
.primitive(INT64, repetition)
.as(DECIMAL)
@@ -534,14 +538,6 @@ private[parquet] class CatalystSchemaConverter(
throw new AnalysisException(s"Unsupported data type $field.dataType")
}
}
-
- // Max precision of a decimal value stored in `numBytes` bytes
- private def maxPrecisionForBytes(numBytes: Int): Int = {
- Math.round( // convert double to long
- Math.floor(Math.log10( // number of base-10 digits
- Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes
- .asInstanceOf[Int]
- }
}
@@ -584,4 +580,16 @@ private[parquet] object CatalystSchemaConverter {
computeMinBytesForPrecision(precision)
}
}
+
+ val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4)
+
+ val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8)
+
+ // Max precision of a decimal value stored in `numBytes` bytes
+ def maxPrecisionForBytes(numBytes: Int): Int = {
+ Math.round( // convert double to long
+ Math.floor(Math.log10( // number of base-10 digits
+ Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes
+ .asInstanceOf[Int]
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala
similarity index 98%
rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala
index 1551afd7b7bf2..2c6b914328b60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.parquet
+package org.apache.spark.sql.execution.datasources.parquet
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala
similarity index 87%
rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala
index 2332a36468dbc..ccd7ebf319af9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetConverter.scala
@@ -15,10 +15,10 @@
* limitations under the License.
*/
-package org.apache.spark.sql.parquet
+package org.apache.spark.sql.execution.datasources.parquet
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.ArrayData
+import org.apache.spark.sql.types.{MapData, ArrayData}
// TODO Removes this while fixing SPARK-8848
private[sql] object CatalystConverter {
@@ -33,7 +33,7 @@ private[sql] object CatalystConverter {
val MAP_SCHEMA_NAME = "map"
// TODO: consider using Array[T] for arrays to avoid boxing of primitive types
- type ArrayScalaType[T] = ArrayData
- type StructScalaType[T] = InternalRow
- type MapScalaType[K, V] = Map[K, V]
+ type ArrayScalaType = ArrayData
+ type StructScalaType = InternalRow
+ type MapScalaType = MapData
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
similarity index 90%
rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index d57b789f5c1c7..63915e0a28655 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.parquet
+package org.apache.spark.sql.execution.datasources.parquet
import java.io.Serializable
import java.nio.ByteBuffer
@@ -25,9 +25,10 @@ import org.apache.hadoop.conf.Configuration
import org.apache.parquet.filter2.compat.FilterCompat
import org.apache.parquet.filter2.compat.FilterCompat._
import org.apache.parquet.filter2.predicate.FilterApi._
-import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Statistics}
-import org.apache.parquet.filter2.predicate.UserDefinedPredicate
+import org.apache.parquet.filter2.predicate._
import org.apache.parquet.io.api.Binary
+import org.apache.parquet.schema.OriginalType
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.expressions._
@@ -197,6 +198,8 @@ private[sql] object ParquetFilters {
def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = {
val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap
+ relaxParquetValidTypeMap
+
// NOTE:
//
// For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
@@ -239,6 +242,37 @@ private[sql] object ParquetFilters {
}
}
+ // !! HACK ALERT !!
+ //
+ // This lazy val is a workaround for PARQUET-201, and should be removed once we upgrade to
+ // parquet-mr 1.8.1 or higher versions.
+ //
+ // In Parquet, not all types of columns can be used for filter push-down optimization. The set
+ // of valid column types is controlled by `ValidTypeMap`. Unfortunately, in parquet-mr 1.7.0 and
+ // prior versions, the limitation is too strict, and doesn't allow `BINARY (ENUM)` columns to be
+ // pushed down.
+ //
+ // This restriction is problematic for Spark SQL, because Spark SQL doesn't have a type that maps
+ // to Parquet original type `ENUM` directly, and always converts `ENUM` to `StringType`. Thus,
+ // a predicate involving an `ENUM` field can be pushed down as a string column, which is perfectly
+ // legal except that it fails the `ValidTypeMap` check.
+ //
+ // Here we add `BINARY (ENUM)` into `ValidTypeMap` lazily via reflection to work around this issue.
+ private lazy val relaxParquetValidTypeMap: Unit = {
+ val constructor = Class
+ .forName(classOf[ValidTypeMap].getCanonicalName + "$FullTypeDescriptor")
+ .getDeclaredConstructor(classOf[PrimitiveTypeName], classOf[OriginalType])
+
+ constructor.setAccessible(true)
+ val enumTypeDescriptor = constructor
+ .newInstance(PrimitiveTypeName.BINARY, OriginalType.ENUM)
+ .asInstanceOf[AnyRef]
+
+ val addMethod = classOf[ValidTypeMap].getDeclaredMethods.find(_.getName == "add").get
+ addMethod.setAccessible(true)
+ addMethod.invoke(null, classOf[Binary], enumTypeDescriptor)
+ }
+
/**
* Converts Catalyst predicate expressions to Parquet filter predicates.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
similarity index 97%
rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index b4337a48dbd80..52fac18ba187a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.parquet
+package org.apache.spark.sql.execution.datasources.parquet
import java.net.URI
import java.util.logging.{Level, Logger => JLogger}
@@ -49,7 +49,10 @@ import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
-private[sql] class DefaultSource extends HadoopFsRelationProvider {
+private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+
+ override def shortName(): String = "parquet"
+
override def createRelation(
sqlContext: SQLContext,
paths: Array[String],
@@ -62,7 +65,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider {
// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext)
- extends OutputWriterInternal {
+ extends OutputWriter {
private val recordWriter: RecordWriter[Void, InternalRow] = {
val outputFormat = {
@@ -87,7 +90,9 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext
outputFormat.getRecordWriter(context)
}
- override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row)
+ override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
+
+ override protected[sql] def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row)
override def close(): Unit = recordWriter.close(context)
}
@@ -204,6 +209,13 @@ private[sql] class ParquetRelation(
override def prepareJobForWrite(job: Job): OutputWriterFactory = {
val conf = ContextUtil.getConfiguration(job)
+ // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible
+ val committerClassname = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key)
+ if (committerClassname == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") {
+ conf.set(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key,
+ classOf[DirectParquetOutputCommitter].getCanonicalName)
+ }
+
val committerClass =
conf.getClass(
SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key,
@@ -291,7 +303,6 @@ private[sql] class ParquetRelation(
initDriverSideJobFuncOpt = Some(setInputPaths),
initLocalJobFuncOpt = Some(initLocalJobFuncOpt),
inputFormatClass = classOf[ParquetInputFormat[InternalRow]],
- keyClass = classOf[Void],
valueClass = classOf[InternalRow]) {
val cacheMetadata = useMetadataCache
@@ -328,7 +339,7 @@ private[sql] class ParquetRelation(
new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
}
}
- }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row]
+ }.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row]
}
}
@@ -667,7 +678,7 @@ private[sql] object ParquetRelation extends Logging {
val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec
val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration)
- // HACK ALERT:
+ // !! HACK ALERT !!
//
// Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es
// to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable`
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala
similarity index 93%
rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala
index 79dd16b7b0c39..3191cf3d121bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.parquet
+package org.apache.spark.sql.execution.datasources.parquet
import java.math.BigInteger
import java.nio.{ByteBuffer, ByteOrder}
@@ -88,13 +88,13 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
case t: UserDefinedType[_] => writeValue(t.sqlType, value)
case t @ ArrayType(_, _) => writeArray(
t,
- value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
+ value.asInstanceOf[CatalystConverter.ArrayScalaType])
case t @ MapType(_, _, _) => writeMap(
t,
- value.asInstanceOf[CatalystConverter.MapScalaType[_, _]])
+ value.asInstanceOf[CatalystConverter.MapScalaType])
case t @ StructType(_) => writeStruct(
t,
- value.asInstanceOf[CatalystConverter.StructScalaType[_]])
+ value.asInstanceOf[CatalystConverter.StructScalaType])
case _ => writePrimitive(schema.asInstanceOf[AtomicType], value)
}
}
@@ -124,7 +124,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] def writeStruct(
schema: StructType,
- struct: CatalystConverter.StructScalaType[_]): Unit = {
+ struct: CatalystConverter.StructScalaType): Unit = {
if (struct != null) {
val fields = schema.fields.toArray
writer.startGroup()
@@ -143,7 +143,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] def writeArray(
schema: ArrayType,
- array: CatalystConverter.ArrayScalaType[_]): Unit = {
+ array: CatalystConverter.ArrayScalaType): Unit = {
val elementType = schema.elementType
writer.startGroup()
if (array.numElements() > 0) {
@@ -154,7 +154,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.startGroup()
if (!array.isNullAt(i)) {
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
- writeValue(elementType, array.get(i))
+ writeValue(elementType, array.get(i, elementType))
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
}
writer.endGroup()
@@ -165,7 +165,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
var i = 0
while (i < array.numElements()) {
- writeValue(elementType, array.get(i))
+ writeValue(elementType, array.get(i, elementType))
i = i + 1
}
writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
@@ -176,11 +176,12 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
private[parquet] def writeMap(
schema: MapType,
- map: CatalystConverter.MapScalaType[_, _]): Unit = {
+ map: CatalystConverter.MapScalaType): Unit = {
writer.startGroup()
- if (map.size > 0) {
+ val length = map.numElements()
+ if (length > 0) {
writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0)
- for ((key, value) <- map) {
+ map.foreach(schema.keyType, schema.valueType, (key, value) => {
writer.startGroup()
writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
writeValue(schema.keyType, key)
@@ -191,7 +192,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
}
writer.endGroup()
- }
+ })
writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0)
}
writer.endGroup()
@@ -293,8 +294,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
case BinaryType =>
writer.addBinary(Binary.fromByteArray(record.getBinary(index)))
- case DecimalType.Fixed(precision, _) =>
- writeDecimal(record.getDecimal(index), precision)
+ case DecimalType.Fixed(precision, scale) =>
+ writeDecimal(record.getDecimal(index, precision, scale), precision)
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala
similarity index 99%
rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala
index 3854f5bd39fb1..019db34fc666d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.parquet
+package org.apache.spark.sql.execution.datasources.parquet
import java.io.IOException
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 11bb49b8d83de..40ca8bf4095d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -101,7 +101,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
}
}
- case logical.InsertIntoTable(LogicalRelation(r: HadoopFsRelation), part, _, _, _) =>
+ case logical.InsertIntoTable(
+ LogicalRelation(r: HadoopFsRelation), part, query, overwrite, _) =>
// We need to make sure the partition columns specified by users do match partition
// columns of the relation.
val existingPartitionColumns = r.partitionColumns.fieldNames.toSet
@@ -115,6 +116,17 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
// OK
}
+ // Get all input data source relations of the query.
+ val srcRelations = query.collect {
+ case LogicalRelation(src: BaseRelation) => src
+ }
+ if (srcRelations.contains(r)) {
+ failAnalysis(
+ "Cannot insert overwrite into table that is also being read from.")
+ } else {
+ // OK
+ }
+
case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) =>
// The relation in l is not an InsertableRelation.
failAnalysis(s"$l does not allow insertion.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index f26f41fb75d57..74892e4e13fa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -17,21 +17,16 @@
package org.apache.spark.sql.execution
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.unsafe.types.UTF8String
-
import scala.collection.mutable.HashSet
-import org.apache.spark.{AccumulatorParam, Accumulator, Logging}
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.types._
+import org.apache.spark.{Accumulator, AccumulatorParam, Logging}
/**
- * :: DeveloperApi ::
* Contains methods for debugging query execution.
*
* Usage:
@@ -53,10 +48,8 @@ package object debug {
}
/**
- * :: DeveloperApi ::
* Augments [[DataFrame]]s with debug methods.
*/
- @DeveloperApi
implicit class DebugQuery(query: DataFrame) extends Logging {
def debug(): Unit = {
val plan = query.queryExecution.executedPlan
@@ -72,23 +65,6 @@ package object debug {
case _ =>
}
}
-
- def typeCheck(): Unit = {
- val plan = query.queryExecution.executedPlan
- val visited = new collection.mutable.HashSet[TreeNodeRef]()
- val debugPlan = plan transform {
- case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) =>
- visited += new TreeNodeRef(s)
- TypeCheck(s)
- }
- try {
- logDebug(s"Results returned: ${debugPlan.execute().count()}")
- } catch {
- case e: Exception =>
- def unwrap(e: Throwable): Throwable = if (e.getCause == null) e else unwrap(e.getCause)
- logDebug(s"Deepest Error: ${unwrap(e)}")
- }
- }
}
private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
@@ -148,70 +124,4 @@ package object debug {
}
}
}
-
- /**
- * Helper functions for checking that runtime types match a given schema.
- */
- private[sql] object TypeCheck {
- def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match {
- case (null, _) =>
-
- case (row: InternalRow, StructType(fields)) =>
- row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
- case (a: ArrayData, ArrayType(elemType, _)) =>
- a.toArray().foreach(typeCheck(_, elemType))
- case (m: Map[_, _], MapType(keyType, valueType, _)) =>
- m.keys.foreach(typeCheck(_, keyType))
- m.values.foreach(typeCheck(_, valueType))
-
- case (_: Long, LongType) =>
- case (_: Int, IntegerType) =>
- case (_: UTF8String, StringType) =>
- case (_: Float, FloatType) =>
- case (_: Byte, ByteType) =>
- case (_: Short, ShortType) =>
- case (_: Boolean, BooleanType) =>
- case (_: Double, DoubleType) =>
- case (_: Int, DateType) =>
- case (_: Long, TimestampType) =>
- case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType)
-
- case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t")
- }
- }
-
- /**
- * Augments [[DataFrame]]s with debug methods.
- */
- private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan {
- import TypeCheck._
-
- override def nodeName: String = ""
-
- /* Only required when defining this class in a REPL.
- override def makeCopy(args: Array[Object]): this.type =
- TypeCheck(args(0).asInstanceOf[SparkPlan]).asInstanceOf[this.type]
- */
-
- def output: Seq[Attribute] = child.output
-
- def children: List[SparkPlan] = child :: Nil
-
- protected override def doExecute(): RDD[InternalRow] = {
- child.execute().map { row =>
- try typeCheck(row, child.schema) catch {
- case e: Exception =>
- sys.error(
- s"""
- |ERROR WHEN TYPE CHECKING QUERY
- |==============================
- |$e
- |======== BAD TREE ============
- |$child
- """.stripMargin)
- }
- row
- }
- }
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 624efc1b1d734..2e108cb814516 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -25,8 +25,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.ThreadUtils
+import org.apache.spark.{InternalAccumulator, TaskContext}
/**
* :: DeveloperApi ::
@@ -44,6 +46,11 @@ case class BroadcastHashJoin(
right: SparkPlan)
extends BinaryNode with HashJoin {
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
val timeout: Duration = {
val timeoutValue = sqlContext.conf.broadcastTimeout
if (timeoutValue < 0) {
@@ -58,25 +65,65 @@ case class BroadcastHashJoin(
override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+ // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value
+ // for the same query.
@transient
- private val broadcastFuture = future {
- // Note that we use .execute().collect() because we don't want to convert data to Scala types
- val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size)
- sparkContext.broadcast(hashed)
- }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
+ private lazy val broadcastFuture = {
+ val numBuildRows = buildSide match {
+ case BuildLeft => longMetric("numLeftRows")
+ case BuildRight => longMetric("numRightRows")
+ }
+
+ // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ future {
+ // This will run in another thread. Set the execution id so that we can connect these jobs
+ // with the correct execution.
+ SQLExecution.withExecutionId(sparkContext, executionId) {
+ // Note that we use .execute().collect() because we don't want to convert data to Scala
+ // types
+ val input: Array[InternalRow] = buildPlan.execute().map { row =>
+ numBuildRows += 1
+ row.copy()
+ }.collect()
+ // The following line doesn't run in a job so we cannot track the metric value. However, we
+ // have already tracked it in the above lines. So here we can use
+ // `SQLMetrics.nullLongMetric` to ignore it.
+ val hashed = HashedRelation(
+ input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+ sparkContext.broadcast(hashed)
+ }
+ }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
+ }
+
+ protected override def doPrepare(): Unit = {
+ broadcastFuture
+ }
protected override def doExecute(): RDD[InternalRow] = {
+ val numStreamedRows = buildSide match {
+ case BuildLeft => longMetric("numRightRows")
+ case BuildRight => longMetric("numLeftRows")
+ }
+ val numOutputRows = longMetric("numOutputRows")
+
val broadcastRelation = Await.result(broadcastFuture, timeout)
streamedPlan.execute().mapPartitions { streamedIter =>
- hashJoin(streamedIter, broadcastRelation.value)
+ val hashedRelation = broadcastRelation.value
+ hashedRelation match {
+ case unsafe: UnsafeHashedRelation =>
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ case _ =>
+ }
+ hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
}
}
}
object BroadcastHashJoin {
- private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService(
+ private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index 77e7fe71009b7..69a8b95eaa7ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -24,10 +24,11 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.{InternalAccumulator, TaskContext}
/**
* :: DeveloperApi ::
@@ -45,6 +46,11 @@ case class BroadcastHashOuterJoin(
left: SparkPlan,
right: SparkPlan) extends BinaryNode with HashOuterJoin {
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
val timeout = {
val timeoutValue = sqlContext.conf.broadcastTimeout
if (timeoutValue < 0) {
@@ -57,15 +63,56 @@ case class BroadcastHashOuterJoin(
override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+ override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
+
+ // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value
+ // for the same query.
@transient
- private val broadcastFuture = future {
- // Note that we use .execute().collect() because we don't want to convert data to Scala types
- val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
- val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
- sparkContext.broadcast(hashed)
- }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
+ private lazy val broadcastFuture = {
+ val numBuildRows = joinType match {
+ case RightOuter => longMetric("numLeftRows")
+ case LeftOuter => longMetric("numRightRows")
+ case x =>
+ throw new IllegalArgumentException(
+ s"HashOuterJoin should not take $x as the JoinType")
+ }
+
+ // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ future {
+ // This will run in another thread. Set the execution id so that we can connect these jobs
+ // with the correct execution.
+ SQLExecution.withExecutionId(sparkContext, executionId) {
+ // Note that we use .execute().collect() because we don't want to convert data to Scala
+ // types
+ val input: Array[InternalRow] = buildPlan.execute().map { row =>
+ numBuildRows += 1
+ row.copy()
+ }.collect()
+ // The following line doesn't run in a job so we cannot track the metric value. However, we
+ // have already tracked it in the above lines. So here we can use
+ // `SQLMetrics.nullLongMetric` to ignore it.
+ val hashed = HashedRelation(
+ input.iterator, SQLMetrics.nullLongMetric, buildKeyGenerator, input.size)
+ sparkContext.broadcast(hashed)
+ }
+ }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
+ }
+
+ protected override def doPrepare(): Unit = {
+ broadcastFuture
+ }
override def doExecute(): RDD[InternalRow] = {
+ val numStreamedRows = joinType match {
+ case RightOuter => longMetric("numRightRows")
+ case LeftOuter => longMetric("numLeftRows")
+ case x =>
+ throw new IllegalArgumentException(
+ s"HashOuterJoin should not take $x as the JoinType")
+ }
+ val numOutputRows = longMetric("numOutputRows")
+
val broadcastRelation = Await.result(broadcastFuture, timeout)
streamedPlan.execute().mapPartitions { streamedIter =>
@@ -73,19 +120,29 @@ case class BroadcastHashOuterJoin(
val hashTable = broadcastRelation.value
val keyGenerator = streamedKeyGenerator
+ hashTable match {
+ case unsafe: UnsafeHashedRelation =>
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ case _ =>
+ }
+
+ val resultProj = resultProjection
joinType match {
case LeftOuter =>
streamedIter.flatMap(currentRow => {
+ numStreamedRows += 1
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey))
+ leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
})
case RightOuter =>
streamedIter.flatMap(currentRow => {
+ numStreamedRows += 1
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow)
+ rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
})
case x =>
@@ -95,9 +152,3 @@ case class BroadcastHashOuterJoin(
}
}
}
-
-object BroadcastHashOuterJoin {
-
- private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService(
- ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128))
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index a60593911f94f..78a8c16c62bca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql.execution.joins
+import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* :: DeveloperApi ::
@@ -36,22 +38,42 @@ case class BroadcastLeftSemiJoinHash(
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
protected override def doExecute(): RDD[InternalRow] = {
- val input = right.execute().map(_.copy()).collect()
+ val numLeftRows = longMetric("numLeftRows")
+ val numRightRows = longMetric("numRightRows")
+ val numOutputRows = longMetric("numOutputRows")
+
+ val input = right.execute().map { row =>
+ numRightRows += 1
+ row.copy()
+ }.collect()
if (condition.isEmpty) {
- val hashSet = buildKeyHashSet(input.toIterator)
+ val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric)
val broadcastedRelation = sparkContext.broadcast(hashSet)
left.execute().mapPartitions { streamIter =>
- hashSemiJoin(streamIter, broadcastedRelation.value)
+ hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows)
}
} else {
- val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size)
+ val hashRelation =
+ HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size)
val broadcastedRelation = sparkContext.broadcast(hashRelation)
left.execute().mapPartitions { streamIter =>
- hashSemiJoin(streamIter, broadcastedRelation.value)
+ val hashedRelation = broadcastedRelation.value
+ hashedRelation match {
+ case unsafe: UnsafeHashedRelation =>
+ TaskContext.get().internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ case _ =>
+ }
+ hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 83b726a8e2897..28c88b1b03d02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.collection.CompactBuffer
/**
@@ -38,6 +39,11 @@ case class BroadcastNestedLoopJoin(
condition: Option[Expression]) extends BinaryNode {
// TODO: Override requiredChildDistribution.
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
/** BuildRight means the right relation <=> the broadcast relation. */
private val (streamed, broadcast) = buildSide match {
case BuildRight => (left, right)
@@ -47,7 +53,7 @@ case class BroadcastNestedLoopJoin(
override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
override def canProcessUnsafeRows: Boolean = true
- @transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
+ private[this] def genResultProjection: InternalRow => InternalRow = {
if (outputsUnsafeRows) {
UnsafeProjection.create(schema)
} else {
@@ -65,8 +71,9 @@ case class BroadcastNestedLoopJoin(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case _ =>
- left.output ++ right.output
+ case x =>
+ throw new IllegalArgumentException(
+ s"BroadcastNestedLoopJoin should not take $x as the JoinType")
}
}
@@ -74,9 +81,17 @@ case class BroadcastNestedLoopJoin(
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
protected override def doExecute(): RDD[InternalRow] = {
+ val (numStreamedRows, numBuildRows) = buildSide match {
+ case BuildRight => (longMetric("numLeftRows"), longMetric("numRightRows"))
+ case BuildLeft => (longMetric("numRightRows"), longMetric("numLeftRows"))
+ }
+ val numOutputRows = longMetric("numOutputRows")
+
val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map(_.copy())
- .collect().toIndexedSeq)
+ sparkContext.broadcast(broadcast.execute().map { row =>
+ numBuildRows += 1
+ row.copy()
+ }.collect().toIndexedSeq)
/** All rows that either match both-way, or rows from streamed joined with nulls. */
val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
@@ -88,20 +103,22 @@ case class BroadcastNestedLoopJoin(
val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)
+ val resultProj = genResultProjection
streamedIter.foreach { streamedRow =>
var i = 0
var streamRowMatched = false
+ numStreamedRows += 1
while (i < broadcastedRelation.value.size) {
val broadcastedRow = broadcastedRelation.value(i)
buildSide match {
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
- matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy()
+ matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy()
streamRowMatched = true
includedBroadcastTuples += i
case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
- matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy()
+ matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy()
streamRowMatched = true
includedBroadcastTuples += i
case _ =>
@@ -111,9 +128,9 @@ case class BroadcastNestedLoopJoin(
(streamRowMatched, joinType, buildSide) match {
case (false, LeftOuter | FullOuter, BuildRight) =>
- matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy()
+ matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
case (false, RightOuter | FullOuter, BuildLeft) =>
- matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy()
+ matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
case _ =>
}
}
@@ -127,6 +144,8 @@ case class BroadcastNestedLoopJoin(
val leftNulls = new GenericMutableRow(left.output.size)
val rightNulls = new GenericMutableRow(right.output.size)
+ val resultProj = genResultProjection
+
/** Rows from broadcasted joined with nulls. */
val broadcastRowsWithNulls: Seq[InternalRow] = {
val buf: CompactBuffer[InternalRow] = new CompactBuffer()
@@ -138,7 +157,7 @@ case class BroadcastNestedLoopJoin(
joinedRow.withLeft(leftNulls)
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
- buf += resultProjection(joinedRow.withRight(rel(i))).copy()
+ buf += resultProj(joinedRow.withRight(rel(i))).copy()
}
i += 1
}
@@ -147,7 +166,7 @@ case class BroadcastNestedLoopJoin(
joinedRow.withRight(rightNulls)
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
- buf += resultProjection(joinedRow.withLeft(rel(i))).copy()
+ buf += resultProj(joinedRow.withLeft(rel(i))).copy()
}
i += 1
}
@@ -158,6 +177,12 @@ case class BroadcastNestedLoopJoin(
// TODO: Breaks lineage.
sparkContext.union(
- matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls))
+ matchesOrStreamedRowsWithNulls.flatMap(_._1),
+ sparkContext.makeRDD(broadcastRowsWithNulls)
+ ).map { row =>
+ // `broadcastRowsWithNulls` doesn't run in a job, so we have to track numOutputRows here.
+ numOutputRows += 1
+ row
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index 261b4724159fb..2115f40702286 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* :: DeveloperApi ::
@@ -30,13 +31,31 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output ++ right.output
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
protected override def doExecute(): RDD[InternalRow] = {
- val leftResults = left.execute().map(_.copy())
- val rightResults = right.execute().map(_.copy())
+ val numLeftRows = longMetric("numLeftRows")
+ val numRightRows = longMetric("numRightRows")
+ val numOutputRows = longMetric("numOutputRows")
+
+ val leftResults = left.execute().map { row =>
+ numLeftRows += 1
+ row.copy()
+ }
+ val rightResults = right.execute().map { row =>
+ numRightRows += 1
+ row.copy()
+ }
leftResults.cartesian(rightResults).mapPartitions { iter =>
val joinedRow = new JoinedRow
- iter.map(r => joinedRow(r._1, r._2))
+ iter.map { r =>
+ numOutputRows += 1
+ joinedRow(r._1, r._2)
+ }
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 6b3d1652923fd..7ce4a517838cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.LongSQLMetric
trait HashJoin {
@@ -44,7 +45,8 @@ trait HashJoin {
override def output: Seq[Attribute] = left.output ++ right.output
protected[this] def isUnsafeMode: Boolean = {
- (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys)
+ (self.codegenEnabled && self.unsafeEnabled
+ && UnsafeProjection.canSupport(buildKeys)
&& UnsafeProjection.canSupport(self.schema))
}
@@ -52,14 +54,14 @@ trait HashJoin {
override def canProcessUnsafeRows: Boolean = isUnsafeMode
override def canProcessSafeRows: Boolean = !isUnsafeMode
- @transient protected lazy val buildSideKeyGenerator: Projection =
+ protected def buildSideKeyGenerator: Projection =
if (isUnsafeMode) {
UnsafeProjection.create(buildKeys, buildPlan.output)
} else {
newMutableProjection(buildKeys, buildPlan.output)()
}
- @transient protected lazy val streamSideKeyGenerator: Projection =
+ protected def streamSideKeyGenerator: Projection =
if (isUnsafeMode) {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
} else {
@@ -68,7 +70,9 @@ trait HashJoin {
protected def hashJoin(
streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation): Iterator[InternalRow] =
+ numStreamRows: LongSQLMetric,
+ hashedRelation: HashedRelation,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] =
{
new Iterator[InternalRow] {
private[this] var currentStreamedRow: InternalRow = _
@@ -97,6 +101,7 @@ trait HashJoin {
case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
}
currentMatchPosition += 1
+ numOutputRows += 1
resultProjection(ret)
}
@@ -112,6 +117,7 @@ trait HashJoin {
while (currentHashMatches == null && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
+ numStreamRows += 1
val key = joinKeys(currentStreamedRow)
if (!key.anyNull) {
currentHashMatches = hashedRelation.get(key)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 7e671e7914f1a..66903347c88c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -22,9 +22,9 @@ import java.util.{HashMap => JavaHashMap}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.LongSQLMetric
import org.apache.spark.util.collection.CompactBuffer
@DeveloperApi
@@ -38,14 +38,6 @@ trait HashOuterJoin {
val left: SparkPlan
val right: SparkPlan
- override def outputPartitioning: Partitioning = joinType match {
- case LeftOuter => left.outputPartitioning
- case RightOuter => right.outputPartitioning
- case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
- case x =>
- throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
- }
-
override def output: Seq[Attribute] = {
joinType match {
case LeftOuter =>
@@ -76,7 +68,7 @@ trait HashOuterJoin {
}
protected[this] def isUnsafeMode: Boolean = {
- (self.codegenEnabled && joinType != FullOuter
+ (self.codegenEnabled && self.unsafeEnabled && joinType != FullOuter
&& UnsafeProjection.canSupport(buildKeys)
&& UnsafeProjection.canSupport(self.schema))
}
@@ -85,14 +77,14 @@ trait HashOuterJoin {
override def canProcessUnsafeRows: Boolean = isUnsafeMode
override def canProcessSafeRows: Boolean = !isUnsafeMode
- @transient protected lazy val buildKeyGenerator: Projection =
+ protected def buildKeyGenerator: Projection =
if (isUnsafeMode) {
UnsafeProjection.create(buildKeys, buildPlan.output)
} else {
newMutableProjection(buildKeys, buildPlan.output)()
}
- @transient protected[this] lazy val streamedKeyGenerator: Projection = {
+ protected[this] def streamedKeyGenerator: Projection = {
if (isUnsafeMode) {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
} else {
@@ -100,7 +92,7 @@ trait HashOuterJoin {
}
}
- @transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
+ protected[this] def resultProjection: InternalRow => InternalRow = {
if (isUnsafeMode) {
UnsafeProjection.create(self.schema)
} else {
@@ -122,23 +114,30 @@ trait HashOuterJoin {
protected[this] def leftOuterIterator(
key: InternalRow,
joinedRow: JoinedRow,
- rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
+ rightIter: Iterable[InternalRow],
+ resultProjection: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
val temp = if (rightIter != null) {
rightIter.collect {
- case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy()
+ case r if boundCondition(joinedRow.withRight(r)) => {
+ numOutputRows += 1
+ resultProjection(joinedRow).copy()
+ }
}
} else {
List.empty
}
if (temp.isEmpty) {
- resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
+ numOutputRows += 1
+ resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
} else {
temp
}
} else {
- resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil
+ numOutputRows += 1
+ resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
}
}
ret.iterator
@@ -147,24 +146,30 @@ trait HashOuterJoin {
protected[this] def rightOuterIterator(
key: InternalRow,
leftIter: Iterable[InternalRow],
- joinedRow: JoinedRow): Iterator[InternalRow] = {
+ joinedRow: JoinedRow,
+ resultProjection: InternalRow => InternalRow,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
val ret: Iterable[InternalRow] = {
if (!key.anyNull) {
val temp = if (leftIter != null) {
leftIter.collect {
- case l if boundCondition(joinedRow.withLeft(l)) =>
+ case l if boundCondition(joinedRow.withLeft(l)) => {
+ numOutputRows += 1
resultProjection(joinedRow).copy()
+ }
}
} else {
List.empty
}
if (temp.isEmpty) {
- resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
+ numOutputRows += 1
+ resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
} else {
temp
}
} else {
- resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil
+ numOutputRows += 1
+ resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
}
}
ret.iterator
@@ -172,7 +177,7 @@ trait HashOuterJoin {
protected[this] def fullOuterIterator(
key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow],
- joinedRow: JoinedRow): Iterator[InternalRow] = {
+ joinedRow: JoinedRow, numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
if (!key.anyNull) {
// Store the positions of records in right, if one of its associated row satisfy
// the join condition.
@@ -185,6 +190,7 @@ trait HashOuterJoin {
// append them directly
case (r, idx) if boundCondition(joinedRow.withRight(r)) =>
+ numOutputRows += 1
matched = true
// if the row satisfy the join condition, add its index into the matched set
rightMatchedSet.add(idx)
@@ -197,6 +203,7 @@ trait HashOuterJoin {
// as we don't know whether we need to append it until finish iterating all
// of the records in right side.
// If we didn't get any proper row, then append a single row with empty right.
+ numOutputRows += 1
joinedRow.withRight(rightNullRow).copy()
})
} ++ rightIter.zipWithIndex.collect {
@@ -205,12 +212,15 @@ trait HashOuterJoin {
// Re-visiting the records in right, and append additional row with empty left, if its not
// in the matched set.
case (r, idx) if !rightMatchedSet.contains(idx) =>
+ numOutputRows += 1
joinedRow(leftNullRow, r).copy()
}
} else {
leftIter.iterator.map[InternalRow] { l =>
+ numOutputRows += 1
joinedRow(l, rightNullRow).copy()
} ++ rightIter.iterator.map[InternalRow] { r =>
+ numOutputRows += 1
joinedRow(leftNullRow, r).copy()
}
}
@@ -219,10 +229,12 @@ trait HashOuterJoin {
// This is only used by FullOuter
protected[this] def buildHashTable(
iter: Iterator[InternalRow],
+ numIterRows: LongSQLMetric,
keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = {
val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]()
while (iter.hasNext) {
val currentRow = iter.next()
+ numIterRows += 1
val rowKey = keyGenerator(currentRow)
var existingMatchList = hashTable.get(rowKey)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 97fde8f975bfd..beb141ade616d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.metric.LongSQLMetric
trait HashSemiJoin {
@@ -33,7 +34,8 @@ trait HashSemiJoin {
override def output: Seq[Attribute] = left.output
protected[this] def supportUnsafe: Boolean = {
- (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys)
+ (self.codegenEnabled && self.unsafeEnabled
+ && UnsafeProjection.canSupport(leftKeys)
&& UnsafeProjection.canSupport(rightKeys)
&& UnsafeProjection.canSupport(left.schema)
&& UnsafeProjection.canSupport(right.schema))
@@ -43,14 +45,14 @@ trait HashSemiJoin {
override def canProcessUnsafeRows: Boolean = supportUnsafe
override def canProcessSafeRows: Boolean = !supportUnsafe
- @transient protected lazy val leftKeyGenerator: Projection =
+ protected def leftKeyGenerator: Projection =
if (supportUnsafe) {
UnsafeProjection.create(leftKeys, left.output)
} else {
newMutableProjection(leftKeys, left.output)()
}
- @transient protected lazy val rightKeyGenerator: Projection =
+ protected def rightKeyGenerator: Projection =
if (supportUnsafe) {
UnsafeProjection.create(rightKeys, right.output)
} else {
@@ -60,14 +62,15 @@ trait HashSemiJoin {
@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
- protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
+ protected def buildKeyHashSet(
+ buildIter: Iterator[InternalRow], numBuildRows: LongSQLMetric): java.util.Set[InternalRow] = {
val hashSet = new java.util.HashSet[InternalRow]()
- var currentRow: InternalRow = null
// Create a Hash set of buildKeys
val rightKey = rightKeyGenerator
while (buildIter.hasNext) {
- currentRow = buildIter.next()
+ val currentRow = buildIter.next()
+ numBuildRows += 1
val rowKey = rightKey(currentRow)
if (!rowKey.anyNull) {
val keyExists = hashSet.contains(rowKey)
@@ -76,30 +79,41 @@ trait HashSemiJoin {
}
}
}
+
hashSet
}
protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
- hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
+ numStreamRows: LongSQLMetric,
+ hashSet: java.util.Set[InternalRow],
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
val joinKeys = leftKeyGenerator
streamIter.filter(current => {
+ numStreamRows += 1
val key = joinKeys(current)
- !key.anyNull && hashSet.contains(key)
+ val r = !key.anyNull && hashSet.contains(key)
+ if (r) numOutputRows += 1
+ r
})
}
protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ numStreamRows: LongSQLMetric,
+ hashedRelation: HashedRelation,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
val joinKeys = leftKeyGenerator
val joinedRow = new JoinedRow
streamIter.filter { current =>
+ numStreamRows += 1
val key = joinKeys(current)
lazy val rowBuffer = hashedRelation.get(key)
- !key.anyNull && rowBuffer != null && rowBuffer.exists {
+ val r = !key.anyNull && rowBuffer != null && rowBuffer.exists {
(row: InternalRow) => boundCondition(joinedRow(current, row))
}
+ if (r) numOutputRows += 1
+ r
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 26dbc911e9521..ea02076b41a6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,18 +17,21 @@
package org.apache.spark.sql.execution.joins
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}
-import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
+import org.apache.spark.shuffle.ShuffleMemoryManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
-import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.{SparkConf, SparkEnv}
/**
@@ -63,7 +66,8 @@ private[joins] final class GeneralHashedRelation(
private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
extends HashedRelation with Externalizable {
- private def this() = this(null) // Needed for serialization
+ // Needed for serialization (it is public to make Java serialization work)
+ def this() = this(null)
override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)
@@ -85,7 +89,8 @@ private[joins]
final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
extends HashedRelation with Externalizable {
- private def this() = this(null) // Needed for serialization
+ // Needed for serialization (it is public to make Java serialization work)
+ def this() = this(null)
override def get(key: InternalRow): Seq[InternalRow] = {
val v = hashTable.get(key)
@@ -110,11 +115,13 @@ private[joins] object HashedRelation {
def apply(
input: Iterator[InternalRow],
+ numInputRows: LongSQLMetric,
keyGenerator: Projection,
sizeEstimate: Int = 64): HashedRelation = {
if (keyGenerator.isInstanceOf[UnsafeProjection]) {
- return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
+ return UnsafeHashedRelation(
+ input, numInputRows, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
}
// TODO: Use Spark's HashMap implementation.
@@ -128,6 +135,7 @@ private[joins] object HashedRelation {
// Create a mapping of buildKeys -> rows
while (input.hasNext) {
currentRow = input.next()
+ numInputRows += 1
val rowKey = keyGenerator(currentRow)
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
@@ -181,15 +189,36 @@ private[joins] final class UnsafeHashedRelation(
private[joins] def this() = this(null) // Needed for serialization
// Use BytesToBytesMap in executor for better performance (it's created when deserialization)
+ // This is used in broadcast joins and distributed mode only
@transient private[this] var binaryMap: BytesToBytesMap = _
+ /**
+ * Return the size of the unsafe map on the executors.
+ *
+ * For broadcast joins, this hashed relation is bigger on the driver because it is
+ * represented as a Java hash map there. While serializing the map to the executors,
+ * however, we rehash the contents in a binary map to reduce the memory footprint on
+ * the executors.
+ *
+ * For non-broadcast joins or in local mode, return 0.
+ */
+ def getUnsafeSize: Long = {
+ if (binaryMap != null) {
+ binaryMap.getTotalMemoryConsumption
+ } else {
+ 0
+ }
+ }
+
override def get(key: InternalRow): Seq[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]
if (binaryMap != null) {
// Used in Broadcast join
- val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
- unsafeKey.getSizeInBytes)
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated on the stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc)
if (loc.isDefined) {
val buffer = CompactBuffer[UnsafeRow]()
@@ -197,8 +226,8 @@ private[joins] final class UnsafeHashedRelation(
var offset = loc.getValueAddress.getBaseOffset
val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
while (offset < last) {
- val numFields = PlatformDependent.UNSAFE.getInt(base, offset)
- val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4)
+ val numFields = Platform.getInt(base, offset)
+ val sizeInBytes = Platform.getInt(base, offset + 4)
offset += 8
val row = new UnsafeRow
@@ -212,12 +241,12 @@ private[joins] final class UnsafeHashedRelation(
}
} else {
- // Use the JavaHashMap in Local mode or ShuffleHashJoin
+ // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
hashTable.get(unsafeKey)
}
}
- override def writeExternal(out: ObjectOutput): Unit = {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(hashTable.size())
val iter = hashTable.entrySet().iterator()
@@ -229,7 +258,7 @@ private[joins] final class UnsafeHashedRelation(
// write all the values as single byte array
var totalSize = 0L
var i = 0
- while (i < values.size) {
+ while (i < values.length) {
totalSize += values(i).getSizeInBytes + 4 + 4
i += 1
}
@@ -240,7 +269,7 @@ private[joins] final class UnsafeHashedRelation(
out.writeInt(totalSize.toInt)
out.write(key.getBytes)
i = 0
- while (i < values.size) {
+ while (i < values.length) {
// [num of fields] [num of bytes] [row bytes]
// write the integer in native order, so they can be read by UNSAFE.getInt()
if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
@@ -256,17 +285,25 @@ private[joins] final class UnsafeHashedRelation(
}
}
- override def readExternal(in: ObjectInput): Unit = {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val nKeys = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
- val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
- val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
- .getSizeAsBytes("spark.buffer.pageSize", "64m")
+ val pageSizeBytes = Option(SparkEnv.get).map(_.shuffleMemoryManager.pageSizeBytes)
+ .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
+
+ // Dummy shuffle memory manager which always grants all memory allocation requests.
+ // We use this because it doesn't make sense to count shared broadcast variables' memory usage
+ // towards individual tasks' quotas. In the future, we should devise a better way of handling
+ // this.
+ val shuffleMemoryManager =
+ ShuffleMemoryManager.create(maxMemory = Long.MaxValue, pageSizeBytes = pageSizeBytes)
binaryMap = new BytesToBytesMap(
- memoryManager,
- nKeys * 2, // reduce hash collision
+ taskMemoryManager,
+ shuffleMemoryManager,
+ (nKeys * 1.5 + 1).toInt, // reduce hash collision
pageSizeBytes)
var i = 0
@@ -275,20 +312,24 @@ private[joins] final class UnsafeHashedRelation(
while (i < nKeys) {
val keySize = in.readInt()
val valuesSize = in.readInt()
- if (keySize > keyBuffer.size) {
+ if (keySize > keyBuffer.length) {
keyBuffer = new Array[Byte](keySize)
}
in.readFully(keyBuffer, 0, keySize)
- if (valuesSize > valuesBuffer.size) {
+ if (valuesSize > valuesBuffer.length) {
valuesBuffer = new Array[Byte](valuesSize)
}
in.readFully(valuesBuffer, 0, valuesSize)
// put it into binary map
- val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
+ val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
assert(!loc.isDefined, "Duplicated key found!")
- loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
- valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
+ val putSuceeded = loc.putNewKey(
+ keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
+ valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize)
+ if (!putSuceeded) {
+ throw new IOException("Could not allocate memory to grow BytesToBytesMap")
+ }
i += 1
}
}
@@ -298,14 +339,17 @@ private[joins] object UnsafeHashedRelation {
def apply(
input: Iterator[InternalRow],
+ numInputRows: LongSQLMetric,
keyGenerator: UnsafeProjection,
sizeEstimate: Int): HashedRelation = {
+ // Use a Java hash table here because unsafe maps expect fixed size records
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
// Create a mapping of buildKeys -> rows
while (input.hasNext) {
val unsafeRow = input.next().asInstanceOf[UnsafeRow]
+ numInputRows += 1
val rowKey = keyGenerator(unsafeRow)
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index 4443455ef11fe..ad6362542f2ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* :: DeveloperApi ::
@@ -35,6 +36,11 @@ case class LeftSemiJoinBNL(
extends BinaryNode {
// TODO: Override requiredChildDistribution.
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
override def outputPartitioning: Partitioning = streamed.outputPartitioning
override def output: Seq[Attribute] = left.output
@@ -52,13 +58,21 @@ case class LeftSemiJoinBNL(
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
protected override def doExecute(): RDD[InternalRow] = {
+ val numLeftRows = longMetric("numLeftRows")
+ val numRightRows = longMetric("numRightRows")
+ val numOutputRows = longMetric("numOutputRows")
+
val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+ sparkContext.broadcast(broadcast.execute().map { row =>
+ numRightRows += 1
+ row.copy()
+ }.collect().toIndexedSeq)
streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow
streamedIter.filter(streamedRow => {
+ numLeftRows += 1
var i = 0
var matched = false
@@ -69,6 +83,9 @@ case class LeftSemiJoinBNL(
}
i += 1
}
+ if (matched) {
+ numOutputRows += 1
+ }
matched
})
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 26a664104d6fb..18808adaac63f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -21,8 +21,9 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* :: DeveloperApi ::
@@ -37,17 +38,28 @@ case class LeftSemiJoinHash(
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
- override def requiredChildDistribution: Seq[ClusteredDistribution] =
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+ override def outputPartitioning: Partitioning = left.outputPartitioning
+
+ override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
protected override def doExecute(): RDD[InternalRow] = {
+ val numLeftRows = longMetric("numLeftRows")
+ val numRightRows = longMetric("numRightRows")
+ val numOutputRows = longMetric("numOutputRows")
+
right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
if (condition.isEmpty) {
- val hashSet = buildKeyHashSet(buildIter)
- hashSemiJoin(streamIter, hashSet)
+ val hashSet = buildKeyHashSet(buildIter, numRightRows)
+ hashSemiJoin(streamIter, numLeftRows, hashSet, numOutputRows)
} else {
- val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
- hashSemiJoin(streamIter, hashRelation)
+ val hashRelation = HashedRelation(buildIter, numRightRows, rightKeyGenerator)
+ hashSemiJoin(streamIter, numLeftRows, hashRelation, numOutputRows)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5439e10a60b2a..fc8c9439a6f07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -21,8 +21,9 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* :: DeveloperApi ::
@@ -38,15 +39,27 @@ case class ShuffledHashJoin(
right: SparkPlan)
extends BinaryNode with HashJoin {
- override def outputPartitioning: Partitioning = left.outputPartitioning
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
- override def requiredChildDistribution: Seq[ClusteredDistribution] =
+ override def outputPartitioning: Partitioning =
+ PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+
+ override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
protected override def doExecute(): RDD[InternalRow] = {
+ val (numBuildRows, numStreamedRows) = buildSide match {
+ case BuildLeft => (longMetric("numLeftRows"), longMetric("numRightRows"))
+ case BuildRight => (longMetric("numRightRows"), longMetric("numLeftRows"))
+ }
+ val numOutputRows = longMetric("numOutputRows")
+
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
- hashJoin(streamIter, hashed)
+ val hashed = HashedRelation(buildIter, numBuildRows, buildSideKeyGenerator)
+ hashJoin(streamIter, numStreamedRows, hashed, numOutputRows)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
index d29b593207c4d..ed282f98b7d71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -23,9 +23,10 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution}
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* :: DeveloperApi ::
@@ -41,41 +42,65 @@ case class ShuffledHashOuterJoin(
left: SparkPlan,
right: SparkPlan) extends BinaryNode with HashOuterJoin {
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ override def outputPartitioning: Partitioning = joinType match {
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+ case x =>
+ throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
+ val numLeftRows = longMetric("numLeftRows")
+ val numRightRows = longMetric("numRightRows")
+ val numOutputRows = longMetric("numOutputRows")
+
val joinedRow = new JoinedRow()
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
// TODO this probably can be replaced by external sort (sort merged join?)
joinType match {
case LeftOuter =>
- val hashed = HashedRelation(rightIter, buildKeyGenerator)
+ val hashed = HashedRelation(rightIter, numRightRows, buildKeyGenerator)
val keyGenerator = streamedKeyGenerator
+ val resultProj = resultProjection
leftIter.flatMap( currentRow => {
+ numLeftRows += 1
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey))
+ leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows)
})
case RightOuter =>
- val hashed = HashedRelation(leftIter, buildKeyGenerator)
+ val hashed = HashedRelation(leftIter, numLeftRows, buildKeyGenerator)
val keyGenerator = streamedKeyGenerator
+ val resultProj = resultProjection
rightIter.flatMap ( currentRow => {
+ numRightRows += 1
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow)
+ rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows)
})
case FullOuter =>
// TODO(davies): use UnsafeRow
- val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
- val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
+ val leftHashTable =
+ buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output))
+ val rightHashTable =
+ buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output))
(leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
fullOuterIterator(key,
leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST),
- joinedRow)
+ joinedRow,
+ numOutputRows)
}
case x =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index bb18b5403f8e8..6b7322671d6b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -17,15 +17,15 @@
package org.apache.spark.sql.execution.joins
-import java.util.NoSuchElementException
+import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
/**
* :: DeveloperApi ::
@@ -38,16 +38,19 @@ case class SortMergeJoin(
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
override def output: Seq[Attribute] = left.output ++ right.output
- override def outputPartitioning: Partitioning = left.outputPartitioning
+ override def outputPartitioning: Partitioning =
+ PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
- // this is to manually construct an ordering that can be used to compare keys from both sides
- private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))
-
override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
@@ -56,113 +59,276 @@ case class SortMergeJoin(
@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
- private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
+ protected[this] def isUnsafeMode: Boolean = {
+ (codegenEnabled && unsafeEnabled
+ && UnsafeProjection.canSupport(leftKeys)
+ && UnsafeProjection.canSupport(rightKeys)
+ && UnsafeProjection.canSupport(schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessSafeRows: Boolean = !isUnsafeMode
+
+ private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
+ // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
keys.map(SortOrder(_, Ascending))
+ }
protected override def doExecute(): RDD[InternalRow] = {
- val leftResults = left.execute().map(_.copy())
- val rightResults = right.execute().map(_.copy())
+ val numLeftRows = longMetric("numLeftRows")
+ val numRightRows = longMetric("numRightRows")
+ val numOutputRows = longMetric("numOutputRows")
- leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
- new Iterator[InternalRow] {
- // Mutable per row objects.
+ left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+ new RowIterator {
+ // An ordering that can be used to compare keys from both sides.
+ private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+ private[this] var currentLeftRow: InternalRow = _
+ private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
+ private[this] var currentMatchIdx: Int = -1
+ private[this] val smjScanner = new SortMergeJoinScanner(
+ leftKeyGenerator,
+ rightKeyGenerator,
+ keyOrdering,
+ RowIterator.fromScala(leftIter),
+ numLeftRows,
+ RowIterator.fromScala(rightIter),
+ numRightRows
+ )
private[this] val joinRow = new JoinedRow
- private[this] var leftElement: InternalRow = _
- private[this] var rightElement: InternalRow = _
- private[this] var leftKey: InternalRow = _
- private[this] var rightKey: InternalRow = _
- private[this] var rightMatches: CompactBuffer[InternalRow] = _
- private[this] var rightPosition: Int = -1
- private[this] var stop: Boolean = false
- private[this] var matchKey: InternalRow = _
-
- // initialize iterator
- initialize()
-
- override final def hasNext: Boolean = nextMatchingPair()
-
- override final def next(): InternalRow = {
- if (hasNext) {
- // we are using the buffered right rows and run down left iterator
- val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
- rightPosition += 1
- if (rightPosition >= rightMatches.size) {
- rightPosition = 0
- fetchLeft()
- if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
- stop = false
- rightMatches = null
- }
- }
- joinedRow
+ private[this] val resultProjection: (InternalRow) => InternalRow = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(schema)
} else {
- // no more result
- throw new NoSuchElementException
+ identity[InternalRow]
}
}
- private def fetchLeft() = {
- if (leftIter.hasNext) {
- leftElement = leftIter.next()
- leftKey = leftKeyGenerator(leftElement)
- } else {
- leftElement = null
+ override def advanceNext(): Boolean = {
+ if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
+ if (smjScanner.findNextInnerJoinRows()) {
+ currentRightMatches = smjScanner.getBufferedMatches
+ currentLeftRow = smjScanner.getStreamedRow
+ currentMatchIdx = 0
+ } else {
+ currentRightMatches = null
+ currentLeftRow = null
+ currentMatchIdx = -1
+ }
}
- }
-
- private def fetchRight() = {
- if (rightIter.hasNext) {
- rightElement = rightIter.next()
- rightKey = rightKeyGenerator(rightElement)
+ if (currentLeftRow != null) {
+ joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
+ currentMatchIdx += 1
+ numOutputRows += 1
+ true
} else {
- rightElement = null
+ false
}
}
- private def initialize() = {
- fetchLeft()
- fetchRight()
+ override def getRow: InternalRow = resultProjection(joinRow)
+ }.toScala
+ }
+ }
+}
+
+/**
+ * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]].
+ *
+ * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]]
+ * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false`
+ * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return
+ * the matching row from the streamed input and may call [[getBufferedMatches]] to return the
+ * sequence of matching rows from the buffered input (in the case of an outer join, this will return
+ * an empty sequence if there are no matches from the buffered input). For efficiency, both of these
+ * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()`
+ * methods.
+ *
+ * @param streamedKeyGenerator a projection that produces join keys from the streamed input.
+ * @param bufferedKeyGenerator a projection that produces join keys from the buffered input.
+ * @param keyOrdering an ordering which can be used to compare join keys.
+ * @param streamedIter an input whose rows will be streamed.
+ * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that
+ * have the same join key.
+ */
+private[joins] class SortMergeJoinScanner(
+ streamedKeyGenerator: Projection,
+ bufferedKeyGenerator: Projection,
+ keyOrdering: Ordering[InternalRow],
+ streamedIter: RowIterator,
+ numStreamedRows: LongSQLMetric,
+ bufferedIter: RowIterator,
+ numBufferedRows: LongSQLMetric) {
+ private[this] var streamedRow: InternalRow = _
+ private[this] var streamedRowKey: InternalRow = _
+ private[this] var bufferedRow: InternalRow = _
+ // Note: this is guaranteed to never have any null columns:
+ private[this] var bufferedRowKey: InternalRow = _
+ /**
+ * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty
+ */
+ private[this] var matchJoinKey: InternalRow = _
+ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
+ private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+
+ // Initialization (note: do _not_ want to advance streamed here).
+ advancedBufferedToRowWithNullFreeJoinKey()
+
+ // --- Public methods ---------------------------------------------------------------------------
+
+ def getStreamedRow: InternalRow = streamedRow
+
+ def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+
+ /**
+ * Advances both input iterators, stopping when we have found rows with matching join keys.
+ * @return true if matching rows have been found and false otherwise. If this returns true, then
+ * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join
+ * results.
+ */
+ final def findNextInnerJoinRows(): Boolean = {
+ while (advancedStreamed() && streamedRowKey.anyNull) {
+ // Advance the streamed side of the join until we find the next row whose join key contains
+ // no nulls or we hit the end of the streamed iterator.
+ }
+ if (streamedRow == null) {
+ // We have consumed the entire streamed iterator, so there can be no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
+ // The new streamed row has the same join key as the previous row, so return the same matches.
+ true
+ } else if (bufferedRow == null) {
+ // The streamed row's join key does not match the current batch of buffered rows and there are
+ // no more rows to read from the buffered iterator, so there can be no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else {
+ // Advance both the streamed and buffered iterators to find the next pair of matching rows.
+ var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+ do {
+ if (streamedRowKey.anyNull) {
+ advancedStreamed()
+ } else {
+ assert(!bufferedRowKey.anyNull)
+ comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+ if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
+ else if (comp < 0) advancedStreamed()
}
+ } while (streamedRow != null && bufferedRow != null && comp != 0)
+ if (streamedRow == null || bufferedRow == null) {
+ // We have hit the end of one of the iterators, so there can be no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else {
+ // The streamed row's join key matches the current buffered row's join key, so walk through the
+ // buffered iterator to buffer the rest of the matching rows.
+ assert(comp == 0)
+ bufferMatchingRows()
+ true
+ }
+ }
+ }
- /**
- * Searches the right iterator for the next rows that have matches in left side, and store
- * them in a buffer.
- *
- * @return true if the search is successful, and false if the right iterator runs out of
- * tuples.
- */
- private def nextMatchingPair(): Boolean = {
- if (!stop && rightElement != null) {
- // run both side to get the first match pair
- while (!stop && leftElement != null && rightElement != null) {
- val comparing = keyOrdering.compare(leftKey, rightKey)
- // for inner join, we need to filter those null keys
- stop = comparing == 0 && !leftKey.anyNull
- if (comparing > 0 || rightKey.anyNull) {
- fetchRight()
- } else if (comparing < 0 || leftKey.anyNull) {
- fetchLeft()
- }
- }
- rightMatches = new CompactBuffer[InternalRow]()
- if (stop) {
- stop = false
- // iterate the right side to buffer all rows that matches
- // as the records should be ordered, exit when we meet the first that not match
- while (!stop && rightElement != null) {
- rightMatches += rightElement
- fetchRight()
- stop = keyOrdering.compare(leftKey, rightKey) != 0
- }
- if (rightMatches.size > 0) {
- rightPosition = 0
- matchKey = leftKey
- }
- }
+ /**
+ * Advances the streamed input iterator and buffers all rows from the buffered input that
+ * have matching keys.
+ * @return true if the streamed iterator returned a row, false otherwise. If this returns true,
+ * then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer
+ * join results.
+ */
+ final def findNextOuterJoinRows(): Boolean = {
+ if (!advancedStreamed()) {
+ // We have consumed the entire streamed iterator, so there can be no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else {
+ if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
+ // Matches the current group, so do nothing.
+ } else {
+ // The streamed row does not match the current group.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ if (bufferedRow != null && !streamedRowKey.anyNull) {
+ // The buffered iterator could still contain matching rows, so we'll need to walk through
+ // it until we either find matches or pass where they would be found.
+ var comp = 1
+ do {
+ comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+ } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey())
+ if (comp == 0) {
+ // We have found matches, so buffer them (this updates matchJoinKey)
+ bufferMatchingRows()
+ } else {
+ // We have overshot the position where the row would be found, hence no matches.
}
- rightMatches != null && rightMatches.size > 0
}
}
+ // If there is a streamed input then we always return true
+ true
}
}
+
+ // --- Private methods --------------------------------------------------------------------------
+
+ /**
+ * Advance the streamed iterator and compute the new row's join key.
+ * @return true if the streamed iterator returned a row and false otherwise.
+ */
+ private def advancedStreamed(): Boolean = {
+ if (streamedIter.advanceNext()) {
+ streamedRow = streamedIter.getRow
+ streamedRowKey = streamedKeyGenerator(streamedRow)
+ numStreamedRows += 1
+ true
+ } else {
+ streamedRow = null
+ streamedRowKey = null
+ false
+ }
+ }
+
+ /**
+ * Advance the buffered iterator until we find a row with join key that does not contain nulls.
+ * @return true if the buffered iterator returned a row and false otherwise.
+ */
+ private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
+ var foundRow: Boolean = false
+ while (!foundRow && bufferedIter.advanceNext()) {
+ bufferedRow = bufferedIter.getRow
+ bufferedRowKey = bufferedKeyGenerator(bufferedRow)
+ numBufferedRows += 1
+ foundRow = !bufferedRowKey.anyNull
+ }
+ if (!foundRow) {
+ bufferedRow = null
+ bufferedRowKey = null
+ false
+ } else {
+ true
+ }
+ }
+
+ /**
+ * Called when the streamed and buffered join keys match in order to buffer the matching rows.
+ */
+ private def bufferMatchingRows(): Unit = {
+ assert(streamedRowKey != null)
+ assert(!streamedRowKey.anyNull)
+ assert(bufferedRowKey != null)
+ assert(!bufferedRowKey.anyNull)
+ assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
+ // This join key may have been produced by a mutable projection, so we need to make a copy:
+ matchJoinKey = streamedRowKey.copy()
+ bufferedMatches.clear()
+ do {
+ bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
+ advancedBufferedToRowWithNullFreeJoinKey()
+ } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
new file mode 100644
index 0000000000000..dea9e5e580a1e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
+
+/**
+ * :: DeveloperApi ::
+ * Performs a sort merge outer join of two child relations.
+ *
+ * Note: this does not support full outer join yet; see SPARK-9730 for progress on this.
+ */
+@DeveloperApi
+case class SortMergeOuterJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode {
+
+ override private[sql] lazy val metrics = Map(
+ "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
+ "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+ override def output: Seq[Attribute] = {
+ joinType match {
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case x =>
+ throw new IllegalArgumentException(
+ s"${getClass.getSimpleName} should not take $x as the JoinType")
+ }
+ }
+
+ override def outputPartitioning: Partitioning = joinType match {
+ // For left and right outer joins, the output is partitioned by the streamed input's join keys.
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case x =>
+ throw new IllegalArgumentException(
+ s"${getClass.getSimpleName} should not take $x as the JoinType")
+ }
+
+ override def outputOrdering: Seq[SortOrder] = joinType match {
+ // For left and right outer joins, the output is ordered by the streamed input's join keys.
+ case LeftOuter => requiredOrders(leftKeys)
+ case RightOuter => requiredOrders(rightKeys)
+ case x => throw new IllegalArgumentException(
+ s"SortMergeOuterJoin should not take $x as the JoinType")
+ }
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
+
+ private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
+ // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
+ keys.map(SortOrder(_, Ascending))
+ }
+
+ private def isUnsafeMode: Boolean = {
+ (codegenEnabled && unsafeEnabled
+ && UnsafeProjection.canSupport(leftKeys)
+ && UnsafeProjection.canSupport(rightKeys)
+ && UnsafeProjection.canSupport(schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessSafeRows: Boolean = !isUnsafeMode
+
+ private def createLeftKeyGenerator(): Projection = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(leftKeys, left.output)
+ } else {
+ newProjection(leftKeys, left.output)
+ }
+ }
+
+ private def createRightKeyGenerator(): Projection = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(rightKeys, right.output)
+ } else {
+ newProjection(rightKeys, right.output)
+ }
+ }
+
+ override def doExecute(): RDD[InternalRow] = {
+ val numLeftRows = longMetric("numLeftRows")
+ val numRightRows = longMetric("numRightRows")
+ val numOutputRows = longMetric("numOutputRows")
+
+ left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+ // An ordering that can be used to compare keys from both sides.
+ val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+ val boundCondition: (InternalRow) => Boolean = {
+ condition.map { cond =>
+ newPredicate(cond, left.output ++ right.output)
+ }.getOrElse {
+ (r: InternalRow) => true
+ }
+ }
+ val resultProj: InternalRow => InternalRow = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(schema)
+ } else {
+ identity[InternalRow]
+ }
+ }
+
+ joinType match {
+ case LeftOuter =>
+ val smjScanner = new SortMergeJoinScanner(
+ streamedKeyGenerator = createLeftKeyGenerator(),
+ bufferedKeyGenerator = createRightKeyGenerator(),
+ keyOrdering,
+ streamedIter = RowIterator.fromScala(leftIter),
+ numLeftRows,
+ bufferedIter = RowIterator.fromScala(rightIter),
+ numRightRows
+ )
+ val rightNullRow = new GenericInternalRow(right.output.length)
+ new LeftOuterIterator(
+ smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala
+
+ case RightOuter =>
+ val smjScanner = new SortMergeJoinScanner(
+ streamedKeyGenerator = createRightKeyGenerator(),
+ bufferedKeyGenerator = createLeftKeyGenerator(),
+ keyOrdering,
+ streamedIter = RowIterator.fromScala(rightIter),
+ numRightRows,
+ bufferedIter = RowIterator.fromScala(leftIter),
+ numLeftRows
+ )
+ val leftNullRow = new GenericInternalRow(left.output.length)
+ new RightOuterIterator(
+ smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
+
+ case x =>
+ throw new IllegalArgumentException(
+ s"SortMergeOuterJoin should not take $x as the JoinType")
+ }
+ }
+ }
+}
+
+
+private class LeftOuterIterator(
+ smjScanner: SortMergeJoinScanner,
+ rightNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow,
+ numRows: LongSQLMetric
+ ) extends RowIterator {
+ private[this] val joinedRow: JoinedRow = new JoinedRow()
+ private[this] var rightIdx: Int = 0
+ assert(smjScanner.getBufferedMatches.length == 0)
+
+ private def advanceLeft(): Boolean = {
+ rightIdx = 0
+ if (smjScanner.findNextOuterJoinRows()) {
+ joinedRow.withLeft(smjScanner.getStreamedRow)
+ if (smjScanner.getBufferedMatches.isEmpty) {
+ // There are no matching right rows, so return nulls for the right row
+ joinedRow.withRight(rightNullRow)
+ } else {
+ // Find the next row from the right input that satisfies the bound condition
+ if (!advanceRightUntilBoundConditionSatisfied()) {
+ joinedRow.withRight(rightNullRow)
+ }
+ }
+ true
+ } else {
+ // Left input has been exhausted
+ false
+ }
+ }
+
+ private def advanceRightUntilBoundConditionSatisfied(): Boolean = {
+ var foundMatch: Boolean = false
+ while (!foundMatch && rightIdx < smjScanner.getBufferedMatches.length) {
+ foundMatch = boundCondition(joinedRow.withRight(smjScanner.getBufferedMatches(rightIdx)))
+ rightIdx += 1
+ }
+ foundMatch
+ }
+
+ override def advanceNext(): Boolean = {
+ val r = advanceRightUntilBoundConditionSatisfied() || advanceLeft()
+ if (r) numRows += 1
+ r
+ }
+
+ override def getRow: InternalRow = resultProj(joinedRow)
+}
+
+private class RightOuterIterator(
+ smjScanner: SortMergeJoinScanner,
+ leftNullRow: InternalRow,
+ boundCondition: InternalRow => Boolean,
+ resultProj: InternalRow => InternalRow,
+ numRows: LongSQLMetric
+ ) extends RowIterator {
+ private[this] val joinedRow: JoinedRow = new JoinedRow()
+ private[this] var leftIdx: Int = 0
+ assert(smjScanner.getBufferedMatches.length == 0)
+
+ private def advanceRight(): Boolean = {
+ leftIdx = 0
+ if (smjScanner.findNextOuterJoinRows()) {
+ joinedRow.withRight(smjScanner.getStreamedRow)
+ if (smjScanner.getBufferedMatches.isEmpty) {
+ // There are no matching left rows, so return nulls for the left row
+ joinedRow.withLeft(leftNullRow)
+ } else {
+ // Find the next row from the left input that satisfies the bound condition
+ if (!advanceLeftUntilBoundConditionSatisfied()) {
+ joinedRow.withLeft(leftNullRow)
+ }
+ }
+ true
+ } else {
+ // Right input has been exhausted
+ false
+ }
+ }
+
+ private def advanceLeftUntilBoundConditionSatisfied(): Boolean = {
+ var foundMatch: Boolean = false
+ while (!foundMatch && leftIdx < smjScanner.getBufferedMatches.length) {
+ foundMatch = boundCondition(joinedRow.withLeft(smjScanner.getBufferedMatches(leftIdx)))
+ leftIdx += 1
+ }
+ foundMatch
+ }
+
+ override def advanceNext(): Boolean = {
+ val r = advanceLeftUntilBoundConditionSatisfied() || advanceRight()
+ if (r) numRows += 1
+ r
+ }
+
+ override def getRow: InternalRow = resultProj(joinedRow)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
new file mode 100644
index 0000000000000..7a2a98ec18cb8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -0,0 +1,121 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.metric
+
+import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}
+
+/**
+ * Create a layer for specialized metric. We cannot add `@specialized` to
+ * `Accumulable/AccumulableParam` because it will break Java source compatibility.
+ *
+ * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
+ */
+private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
+ name: String, val param: SQLMetricParam[R, T])
+ extends Accumulable[R, T](param.zero, param, Some(name), true)
+
+/**
+ * Create a layer for specialized metric. We cannot add `@specialized` to
+ * `Accumulable/AccumulableParam` because it will break Java source compatibility.
+ */
+private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
+
+ def zero: R
+}
+
+/**
+ * Create a layer for specialized metric. We cannot add `@specialized` to
+ * `Accumulable/AccumulableParam` because it will break Java source compatibility.
+ */
+private[sql] trait SQLMetricValue[T] extends Serializable {
+
+ def value: T
+
+ override def toString: String = value.toString
+}
+
+/**
+ * A wrapper of Long to avoid boxing and unboxing when using Accumulator
+ */
+private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {
+
+ def add(incr: Long): LongSQLMetricValue = {
+ _value += incr
+ this
+ }
+
+ // Although there is a boxing here, it's fine because it's only called in SQLListener
+ override def value: Long = _value
+}
+
+/**
+ * A wrapper of Int to avoid boxing and unboxing when using Accumulator
+ */
+private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] {
+
+ def add(term: Int): IntSQLMetricValue = {
+ _value += term
+ this
+ }
+
+ // Although there is a boxing here, it's fine because it's only called in SQLListener
+ override def value: Int = _value
+}
+
+/**
+ * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
+ * `+=` and `add`.
+ */
+private[sql] class LongSQLMetric private[metric](name: String)
+ extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) {
+
+ override def +=(term: Long): Unit = {
+ localValue.add(term)
+ }
+
+ override def add(term: Long): Unit = {
+ localValue.add(term)
+ }
+}
+
+private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] {
+
+ override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
+
+ override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
+ r1.add(r2.value)
+
+ override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
+
+ override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L)
+}
+
+private[sql] object SQLMetrics {
+
+ def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ val acc = new LongSQLMetric(name)
+ sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ acc
+ }
+
+ /**
+ * A metric whose value will be ignored. Use this one when we need a metric parameter but don't
+ * care about the value.
+ */
+ val nullLongMetric = new LongSQLMetric("null")
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala
index 66237f8f1314b..28fa231e722d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala
@@ -18,12 +18,6 @@
package org.apache.spark.sql
/**
- * :: DeveloperApi ::
- * An execution engine for relational query plans that runs on top Spark and returns RDDs.
- *
- * Note that the operators in this package are created automatically by a query planner using a
- * [[SQLContext]] and are not intended to be used directly by end users of Spark SQL. They are
- * documented here in order to make it easier for others to understand the performance
- * characteristics of query plans that are generated by Spark SQL.
+ * The physical execution component of Spark SQL. Note that this is a private package.
*/
package object execution
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index ef1c6e57dc08a..59f8b079ab333 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -21,7 +21,6 @@ import java.io.OutputStream
import java.util.{List => JList, Map => JMap}
import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
import net.razorvine.pickle._
@@ -66,7 +65,7 @@ private[spark] case class PythonUDF(
* multiple child operators.
*/
private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan
@@ -135,22 +134,18 @@ object EvaluatePython {
new GenericInternalRowWithSchema(values, struct)
case (a: ArrayData, array: ArrayType) =>
- val length = a.numElements()
- val values = new java.util.ArrayList[Any](length)
- var i = 0
- while (i < length) {
- if (a.isNullAt(i)) {
- values.add(null)
- } else {
- values.add(toJava(a.get(i), array.elementType))
- }
- i += 1
- }
+ val values = new java.util.ArrayList[Any](a.numElements())
+ a.foreach(array.elementType, (_, e) => {
+ values.add(toJava(e, array.elementType))
+ })
values
- case (obj: Map[_, _], mt: MapType) => obj.map {
- case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
- }.asJava
+ case (map: MapData, mt: MapType) =>
+ val jmap = new java.util.HashMap[Any, Any](map.numElements())
+ map.foreach(mt.keyType, mt.valueType, (k, v) => {
+ jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType))
+ })
+ jmap
case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
@@ -186,7 +181,7 @@ object EvaluatePython {
case (c: Double, DoubleType) => c
- case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c)
+ case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
case (c: Int, DateType) => c
@@ -206,9 +201,10 @@ object EvaluatePython {
case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
- case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
- case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
- }.toMap
+ case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+ val keys = c.keysIterator.map(fromJava(_, keyType)).toArray
+ val values = c.valuesIterator.map(fromJava(_, valueType)).toArray
+ ArrayBasedMapData(keys, values)
case (c, StructType(fields)) if c.getClass.isArray =>
new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
index 29f3beb3cb3c8..855555dd1d4c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
/**
@@ -33,6 +34,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe")
override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override def outputsUnsafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = false
override def canProcessSafeRows: Boolean = true
@@ -51,6 +54,8 @@ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
@DeveloperApi
case class ConvertToSafe(child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override def outputsUnsafeRows: Boolean = false
override def canProcessUnsafeRows: Boolean = true
override def canProcessSafeRows: Boolean = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index f82208868c3e3..e316930470127 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -17,16 +17,15 @@
package org.apache.spark.sql.execution
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.RDD
+import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext}
+import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Descending, BindReferences, Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines various sort operators.
@@ -78,6 +77,11 @@ case class ExternalSort(
val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering))
sorter.insertAll(iterator.map(r => (r.copy(), null)))
val baseIterator = sorter.iterator.map(_._1)
+ val context = TaskContext.get()
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
// TODO(marmbrus): The complex type signature below thwarts inference for no reason.
CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
}, preservesPartitioning = true)
@@ -97,59 +101,77 @@ case class ExternalSort(
* @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will
* spill every `frequency` records.
*/
-case class UnsafeExternalSort(
+case class TungstenSort(
sortOrder: Seq[SortOrder],
global: Boolean,
child: SparkPlan,
testSpillFrequency: Int = 0)
extends UnaryNode {
- private[this] val schema: StructType = child.schema
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = false
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
override def requiredChildDistribution: Seq[Distribution] =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
- assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled")
- def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
- val ordering = newOrdering(sortOrder, child.output)
- val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output)
- // Hack until we generate separate comparator implementations for ascending vs. descending
- // (or choose to codegen them):
- val prefixComparator = {
- val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression)
- if (sortOrder.head.direction == Descending) {
- new PrefixComparator {
- override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2)
- }
- } else {
- comp
- }
- }
- val prefixComputer = {
- val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression)
- new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = prefixComputer(row)
+ protected override def doExecute(): RDD[InternalRow] = {
+ val schema = child.schema
+ val childOutput = child.output
+ val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes
+
+ /**
+ * Set up the sorter in each partition before computing the parent partition.
+ * This makes sure our sorter is not starved by other sorters used in the same task.
+ */
+ def preparePartition(): UnsafeExternalRowSorter = {
+ val ordering = newOrdering(sortOrder, childOutput)
+
+ // The comparator for comparing prefix
+ val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
+ val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+ // The generator for prefix
+ val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+ val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = {
+ prefixProjection.apply(row).getLong(0)
}
}
- val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
+
+ val sorter = new UnsafeExternalRowSorter(
+ schema, ordering, prefixComparator, prefixComputer, pageSize)
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
- sorter.sort(iterator)
+ sorter
}
- child.execute().mapPartitions(doSort, preservesPartitioning = true)
- }
- override def output: Seq[Attribute] = child.output
+ /** Compute a partition using the sorter already set up previously. */
+ def executePartition(
+ taskContext: TaskContext,
+ partitionIndex: Int,
+ sorter: UnsafeExternalRowSorter,
+ parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+ val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]])
+ taskContext.internalMetricsToAccumulators(
+ InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
+ sortedIterator
+ }
- override def outputOrdering: Seq[SortOrder] = sortOrder
+ // Note: we need to set up the external sorter in each partition before computing
+ // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709).
+ new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter](
+ child.execute(), preparePartition, executePartition, preservesPartitioning = true)
+ }
- override def outputsUnsafeRows: Boolean = true
}
-@DeveloperApi
-object UnsafeExternalSort {
+object TungstenSort {
/**
* Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 9329148aa233c..db463029aedf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -20,17 +20,15 @@ package org.apache.spark.sql.execution.stat
import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.{Row, Column, DataFrame}
private[sql] object FrequentItems extends Logging {
/** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
private class FreqItemCounter(size: Int) extends Serializable {
val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]
-
/**
* Add a new example to the counts if it exists, otherwise deduct the count
* from existing items.
@@ -42,9 +40,15 @@ private[sql] object FrequentItems extends Logging {
if (baseMap.size < size) {
baseMap += key -> count
} else {
- // TODO: Make this more efficient... A flatMap?
- baseMap.retain((k, v) => v > count)
- baseMap.transform((k, v) => v - count)
+ val minCount = baseMap.values.min
+ val remainder = count - minCount
+ if (remainder >= 0) {
+ baseMap += key -> count // something will get kicked out, so we can add this
+ baseMap.retain((k, v) => v > minCount)
+ baseMap.transform((k, v) => v - minCount)
+ } else {
+ baseMap.transform((k, v) => v - count)
+ }
}
}
this
@@ -90,12 +94,12 @@ private[sql] object FrequentItems extends Logging {
(name, originalSchema.fields(index).dataType)
}.toArray
- val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)(
+ val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
seqOp = (counts, row) => {
var i = 0
while (i < numCols) {
val thisMap = counts(i)
- val key = row.get(i, colInfo(i)._2)
+ val key = row.get(i)
thisMap.add(key, 1L)
i += 1
}
@@ -110,13 +114,13 @@ private[sql] object FrequentItems extends Logging {
baseCounts
}
)
- val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_))
- val resultRow = InternalRow(justItems : _*)
+ val justItems = freqItems.map(m => m.baseMap.keys.toArray)
+ val resultRow = Row(justItems : _*)
// append frequent Items to the column name for easy debugging
val outputCols = colInfo.map { v =>
StructField(v._1 + "_freqItems", ArrayType(v._2, false))
}
val schema = StructType(outputCols).toAttributes
- new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
+ new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
new file mode 100644
index 0000000000000..49646a99d68c8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.ui
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.collection.mutable
+import scala.xml.Node
+
+import org.apache.commons.lang3.StringEscapeUtils
+
+import org.apache.spark.Logging
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with Logging {
+
+ private val listener = parent.listener
+
+ override def render(request: HttpServletRequest): Seq[Node] = {
+ val currentTime = System.currentTimeMillis()
+ val content = listener.synchronized {
+ val _content = mutable.ListBuffer[Node]()
+ if (listener.getRunningExecutions.nonEmpty) {
+ _content ++=
+ new RunningExecutionTable(
+ parent, "Running Queries", currentTime,
+ listener.getRunningExecutions.sortBy(_.submissionTime).reverse).toNodeSeq
+ }
+ if (listener.getCompletedExecutions.nonEmpty) {
+ _content ++=
+ new CompletedExecutionTable(
+ parent, "Completed Queries", currentTime,
+ listener.getCompletedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq
+ }
+ if (listener.getFailedExecutions.nonEmpty) {
+ _content ++=
+ new FailedExecutionTable(
+ parent, "Failed Queries", currentTime,
+ listener.getFailedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq
+ }
+ _content
+ }
+ UIUtils.headerSparkPage("SQL", content, parent, Some(5000))
+ }
+}
+
+private[ui] abstract class ExecutionTable(
+ parent: SQLTab,
+ tableId: String,
+ tableName: String,
+ currentTime: Long,
+ executionUIDatas: Seq[SQLExecutionUIData],
+ showRunningJobs: Boolean,
+ showSucceededJobs: Boolean,
+ showFailedJobs: Boolean) {
+
+ protected def baseHeader: Seq[String] = Seq(
+ "ID",
+ "Description",
+ "Submitted",
+ "Duration")
+
+ protected def header: Seq[String]
+
+ protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = {
+ val submissionTime = executionUIData.submissionTime
+ val duration = executionUIData.completionTime.getOrElse(currentTime) - submissionTime
+
+ val runningJobs = executionUIData.runningJobs.map { jobId =>
+ {jobId.toString}
+ }
+ val succeededJobs = executionUIData.succeededJobs.sorted.map { jobId =>
+ {jobId.toString}
+ }
+ val failedJobs = executionUIData.failedJobs.sorted.map { jobId =>
+ {jobId.toString}
+ }
+