diff --git a/.gitignore b/.gitignore
index e1231c7374dc..125c5adb9689 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,6 +25,7 @@ R-unit-tests.log
R/unit-tests.out
R/cran-check.out
R/pkg/vignettes/sparkr-vignettes.html
+R/pkg/tests/fulltests/Rplots.pdf
build/*.jar
build/apache-maven*
build/scala*
diff --git a/LICENSE b/LICENSE
index 66a2e8f13295..39fe0dc46238 100644
--- a/LICENSE
+++ b/LICENSE
@@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
(The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
(The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net)
- (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/)
+ (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/)
(Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/)
(BSD licence) sbt and sbt-launch-lib.bash
(BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index b7fdae58de45..232f5cf31f31 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -429,6 +429,7 @@ export("structField",
"structField.character",
"print.structField",
"structType",
+ "structType.character",
"structType.jobj",
"structType.structField",
"print.structType")
@@ -465,5 +466,6 @@ S3method(print, summary.GBTRegressionModel)
S3method(print, summary.GBTClassificationModel)
S3method(structField, character)
S3method(structField, jobj)
+S3method(structType, character)
S3method(structType, jobj)
S3method(structType, structField)
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 3b9d42d6e715..e7a166c3014c 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1391,6 +1391,10 @@ setMethod("summarize",
})
dapplyInternal <- function(x, func, schema) {
+ if (is.character(schema)) {
+ schema <- structType(schema)
+ }
+
packageNamesArr <- serialize(.sparkREnv[[".packages"]],
connection = NULL)
@@ -1408,6 +1412,8 @@ dapplyInternal <- function(x, func, schema) {
dataFrame(sdf)
}
+setClassUnion("characterOrstructType", c("character", "structType"))
+
#' dapply
#'
#' Apply a function to each partition of a SparkDataFrame.
@@ -1418,10 +1424,11 @@ dapplyInternal <- function(x, func, schema) {
#' to each partition will be passed.
#' The output of func should be a R data.frame.
#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
-#' It must match the output of func.
+#' It must match the output of func. Since Spark 2.3, a DDL-formatted string
+#' is also supported for the schema.
#' @family SparkDataFrame functions
#' @rdname dapply
-#' @aliases dapply,SparkDataFrame,function,structType-method
+#' @aliases dapply,SparkDataFrame,function,characterOrstructType-method
#' @name dapply
#' @seealso \link{dapplyCollect}
#' @export
@@ -1444,6 +1451,17 @@ dapplyInternal <- function(x, func, schema) {
#' y <- cbind(y, y[1] + 1L)
#' },
#' schema)
+#'
+#' # The schema can also be specified in a DDL-formatted string.
+#' schema <- "a INT, b DOUBLE, c STRING, d INT"
+#' df1 <- dapply(
+#' df,
+#' function(x) {
+#' y <- x[x[1] > 1, ]
+#' y <- cbind(y, y[1] + 1L)
+#' },
+#' schema)
+#'
#' collect(df1)
#' # the result
#' # a b c d
@@ -1452,7 +1470,7 @@ dapplyInternal <- function(x, func, schema) {
#' }
#' @note dapply since 2.0.0
setMethod("dapply",
- signature(x = "SparkDataFrame", func = "function", schema = "structType"),
+ signature(x = "SparkDataFrame", func = "function", schema = "characterOrstructType"),
function(x, func, schema) {
dapplyInternal(x, func, schema)
})
@@ -1522,6 +1540,7 @@ setMethod("dapplyCollect",
#' @param schema the schema of the resulting SparkDataFrame after the function is applied.
#' The schema must match to output of \code{func}. It has to be defined for each
#' output column with preferred output column name and corresponding data type.
+#' Since Spark 2.3, a DDL-formatted string is also supported for the schema.
#' @return A SparkDataFrame.
#' @family SparkDataFrame functions
#' @aliases gapply,SparkDataFrame-method
@@ -1541,7 +1560,7 @@ setMethod("dapplyCollect",
#'
#' Here our output contains three columns, the key which is a combination of two
#' columns with data types integer and string and the mean which is a double.
-#' schema <- structType(structField("a", "integer"), structField("c", "string"),
+#' schema <- structType(structField("a", "integer"), structField("c", "string"),
#' structField("avg", "double"))
#' result <- gapply(
#' df,
@@ -1550,6 +1569,15 @@ setMethod("dapplyCollect",
#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
#' }, schema)
#'
+#' The schema can also be specified in a DDL-formatted string.
+#' schema <- "a INT, c STRING, avg DOUBLE"
+#' result <- gapply(
+#' df,
+#' c("a", "c"),
+#' function(key, x) {
+#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+#' }, schema)
+#'
#' We can also group the data and afterwards call gapply on GroupedData.
#' For Example:
#' gdf <- group_by(df, "a", "c")
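
A minimal sketch of the DDL-formatted schema support added above for dapply, assuming an active SparkR session; the input data frame and column names are illustrative only.

    df <- createDataFrame(data.frame(a = 1:3, b = c(1, 2, 3), c = c("x", "y", "z"),
                                     stringsAsFactors = FALSE))
    # The DDL string below is equivalent to
    # structType(structField("a", "integer"), structField("b", "double"),
    #            structField("c", "string"), structField("d", "integer")).
    df1 <- dapply(
      df,
      function(x) {
        y <- x[x$a > 1, ]
        cbind(y, y$a + 1L)
      },
      "a INT, b DOUBLE, c STRING, d INT")
    collect(df1)
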
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index a1f5c4f8cc18..86507f13f038 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -38,10 +38,10 @@ NULL
#'
#' Date time functions defined for \code{Column}.
#'
-#' @param x Column to compute on.
+#' @param x Column to compute on. In \code{window}, it must be a time Column of \code{TimestampType}.
#' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse
-#' x Column to DateType or TimestampType. For \code{trunc}, it is the string used
-#' for specifying the truncation method. For example, "year", "yyyy", "yy" for
+#' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string
+#' to use to specify the truncation method. For example, "year", "yyyy", "yy" for
#' truncate by year, or "month", "mon", "mm" for truncate by month.
#' @param ... additional argument(s).
#' @name column_datetime_functions
@@ -122,7 +122,7 @@ NULL
#' format to. See 'Details'.
#' }
#' @param y Column to compute on.
-#' @param ... additional columns.
+#' @param ... additional Columns.
#' @name column_string_functions
#' @rdname column_string_functions
#' @family string functions
@@ -167,8 +167,7 @@ NULL
#' tmp <- mutate(df, v1 = crc32(df$model), v2 = hash(df$model),
#' v3 = hash(df$model, df$mpg), v4 = md5(df$model),
#' v5 = sha1(df$model), v6 = sha2(df$model, 256))
-#' head(tmp)
-#' }
+#' head(tmp)}
NULL
#' Collection functions for Column operations
@@ -190,7 +189,6 @@ NULL
#' \dontrun{
#' # Dataframe used throughout this doc
#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
-#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1)))
#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
@@ -200,6 +198,34 @@ NULL
#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))}
NULL
+#' Window functions for Column operations
+#'
+#' Window functions defined for \code{Column}.
+#'
+#' @param x In \code{lag} and \code{lead}, it is the column as a character string or a Column
+#' to compute on. In \code{ntile}, it is the number of ntile groups.
+#' @param offset In \code{lag}, the number of rows back from the current row from which to obtain
+#' a value. In \code{lead}, the number of rows after the current row from which to
+#' obtain a value. If not specified, the default is 1.
+#' @param defaultValue (optional) default to use when the offset row does not exist.
+#' @param ... additional argument(s).
+#' @name column_window_functions
+#' @rdname column_window_functions
+#' @family window functions
+#' @examples
+#' \dontrun{
+#' # Dataframe used throughout this doc
+#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
+#' ws <- orderBy(windowPartitionBy("am"), "hp")
+#' tmp <- mutate(df, dist = over(cume_dist(), ws), dense_rank = over(dense_rank(), ws),
+#' lag = over(lag(df$mpg), ws), lead = over(lead(df$mpg, 1), ws),
+#' percent_rank = over(percent_rank(), ws),
+#' rank = over(rank(), ws), row_number = over(row_number(), ws))
+#' # Get ntile group id (1-4) for hp
+#' tmp <- mutate(tmp, ntile = over(ntile(4), ws))
+#' head(tmp)}
+NULL
+
#' @details
#' \code{lit}: A new Column is created to represent the literal value.
#' If the parameter is a Column, it is returned unchanged.
@@ -310,7 +336,8 @@ setMethod("asin",
})
#' @details
-#' \code{atan}: Computes the tangent inverse of the given value.
+#' \code{atan}: Computes the tangent inverse of the given value; the returned angle is in the range
+#' -pi/2 through pi/2.
#'
#' @rdname column_math_functions
#' @export
@@ -366,7 +393,7 @@ setMethod("base64",
})
#' @details
-#' \code{bin}: An expression that returns the string representation of the binary value
+#' \code{bin}: Returns the string representation of the binary value
#' of the given long column. For example, bin("12") returns "1100".
#'
#' @rdname column_math_functions
@@ -573,7 +600,7 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr
})
#' @details
-#' \code{cos}: Computes the cosine of the given value.
+#' \code{cos}: Computes the cosine of the given value. Units in radians.
#'
#' @rdname column_math_functions
#' @aliases cos cos,Column-method
@@ -694,7 +721,7 @@ setMethod("dayofyear",
#' \code{decode}: Computes the first argument into a string from a binary using the provided
#' character set.
#'
-#' @param charset Character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE",
+#' @param charset character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE",
#' "UTF-16LE", "UTF-16").
#'
#' @rdname column_string_functions
@@ -827,7 +854,7 @@ setMethod("hex",
})
#' @details
-#' \code{hour}: Extracts the hours as an integer from a given date/timestamp/string.
+#' \code{hour}: Extracts the hour as an integer from a given date/timestamp/string.
#'
#' @rdname column_datetime_functions
#' @aliases hour hour,Column-method
@@ -1149,7 +1176,7 @@ setMethod("min",
})
#' @details
-#' \code{minute}: Extracts the minutes as an integer from a given date/timestamp/string.
+#' \code{minute}: Extracts the minute as an integer from a given date/timestamp/string.
#'
#' @rdname column_datetime_functions
#' @aliases minute minute,Column-method
@@ -1326,7 +1353,7 @@ setMethod("sd",
})
#' @details
-#' \code{second}: Extracts the seconds as an integer from a given date/timestamp/string.
+#' \code{second}: Extracts the second as an integer from a given date/timestamp/string.
#'
#' @rdname column_datetime_functions
#' @aliases second second,Column-method
@@ -1381,7 +1408,7 @@ setMethod("sign", signature(x = "Column"),
})
#' @details
-#' \code{sin}: Computes the sine of the given value.
+#' \code{sin}: Computes the sine of the given value. Units in radians.
#'
#' @rdname column_math_functions
#' @aliases sin sin,Column-method
@@ -1436,20 +1463,18 @@ setMethod("soundex",
column(jc)
})
-#' Return the partition ID as a column
-#'
-#' Return the partition ID as a SparkDataFrame column.
+#' @details
+#' \code{spark_partition_id}: Returns the partition ID as a SparkDataFrame column.
#' Note that this is nondeterministic because it depends on data partitioning and
#' task scheduling.
+#' This is equivalent to the \code{SPARK_PARTITION_ID} function in SQL.
#'
-#' This is equivalent to the SPARK_PARTITION_ID function in SQL.
-#'
-#' @rdname spark_partition_id
-#' @name spark_partition_id
-#' @aliases spark_partition_id,missing-method
+#' @rdname column_nonaggregate_functions
+#' @aliases spark_partition_id spark_partition_id,missing-method
#' @export
#' @examples
-#' \dontrun{select(df, spark_partition_id())}
+#'
+#' \dontrun{head(select(df, spark_partition_id()))}
#' @note spark_partition_id since 2.0.0
setMethod("spark_partition_id",
signature("missing"),
@@ -1573,7 +1598,7 @@ setMethod("sumDistinct",
})
#' @details
-#' \code{tan}: Computes the tangent of the given value.
+#' \code{tan}: Computes the tangent of the given value. Units in radians.
#'
#' @rdname column_math_functions
#' @aliases tan tan,Column-method
@@ -1872,7 +1897,7 @@ setMethod("year",
#' @details
#' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates
-#' (x, y) to polar coordinates (r, theta).
+#' (x, y) to polar coordinates (r, theta). Units in radians.
#'
#' @rdname column_math_functions
#' @aliases atan2 atan2,Column-method
@@ -2000,7 +2025,7 @@ setMethod("pmod", signature(y = "Column"),
column(jc)
})
-#' @param rsd maximum estimation error allowed (default = 0.05)
+#' @param rsd maximum estimation error allowed (default = 0.05).
#'
#' @rdname column_aggregate_functions
#' @aliases approxCountDistinct,Column-method
@@ -2149,8 +2174,9 @@ setMethod("date_format", signature(y = "Column", x = "character"),
#'
#' @rdname column_collection_functions
#' @param schema a structType object to use as the schema to use when parsing the JSON string.
+#' Since Spark 2.3, a DDL-formatted string is also supported for the schema.
#' @param as.json.array indicating if input string is JSON array of objects or a single object.
-#' @aliases from_json from_json,Column,structType-method
+#' @aliases from_json from_json,Column,characterOrstructType-method
#' @export
#' @examples
#'
@@ -2163,10 +2189,15 @@ setMethod("date_format", signature(y = "Column", x = "character"),
#' df2 <- sql("SELECT named_struct('name', 'Bob') as people")
#' df2 <- mutate(df2, people_json = to_json(df2$people))
#' schema <- structType(structField("name", "string"))
-#' head(select(df2, from_json(df2$people_json, schema)))}
+#' head(select(df2, from_json(df2$people_json, schema)))
+#' head(select(df2, from_json(df2$people_json, "name STRING")))}
#' @note from_json since 2.2.0
-setMethod("from_json", signature(x = "Column", schema = "structType"),
+setMethod("from_json", signature(x = "Column", schema = "characterOrstructType"),
function(x, schema, as.json.array = FALSE, ...) {
+ if (is.character(schema)) {
+ schema <- structType(schema)
+ }
+
if (as.json.array) {
jschema <- callJStatic("org.apache.spark.sql.types.DataTypes",
"createArrayType",
@@ -2192,8 +2223,8 @@ setMethod("from_json", signature(x = "Column", schema = "structType"),
#' @examples
#'
#' \dontrun{
-#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, 'PST'),
-#' to_utc = to_utc_timestamp(df$time, 'PST'))
+#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, "PST"),
+#' to_utc = to_utc_timestamp(df$time, "PST"))
#' head(tmp)}
#' @note from_utc_timestamp since 1.5.0
setMethod("from_utc_timestamp", signature(y = "Column", x = "character"),
@@ -2227,7 +2258,7 @@ setMethod("instr", signature(y = "Column", x = "character"),
#' @details
#' \code{next_day}: Given a date column, returns the first date which is later than the value of
#' the date column that is on the specified day of the week. For example,
-#' \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first Sunday
+#' \code{next_day("2015-07-27", "Sunday")} returns 2015-08-02 because that is the first Sunday
#' after 2015-07-27. Day of the week parameter is case insensitive, and accepts first three or
#' two characters: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun".
#'
@@ -2267,7 +2298,7 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"),
#' tmp <- mutate(df, t1 = add_months(df$time, 1),
#' t2 = date_add(df$time, 2),
#' t3 = date_sub(df$time, 3),
-#' t4 = next_day(df$time, 'Sun'))
+#' t4 = next_day(df$time, "Sun"))
#' head(tmp)}
#' @note add_months since 1.5.0
setMethod("add_months", signature(y = "Column", x = "numeric"),
@@ -2376,8 +2407,8 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"),
})
#' @details
-#' \code{shiftRight}: (Unigned) shifts the given value numBits right. If the given value is a long value,
-#' it will return a long value else it will return an integer value.
+#' \code{shiftRightUnsigned}: (Unsigned) shifts the given value numBits right. If the given value is
+#' a long value, it will return a long value else it will return an integer value.
#'
#' @rdname column_math_functions
#' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method
@@ -2485,14 +2516,13 @@ setMethod("from_unixtime", signature(x = "Column"),
column(jc)
})
-#' window
-#'
-#' Bucketize rows into one or more time windows given a timestamp specifying column. Window
-#' starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
+#' @details
+#' \code{window}: Bucketizes rows into one or more time windows given a timestamp specifying column.
+#' Window starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
#' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
-#' the order of months are not supported.
+#' the order of months are not supported. It returns an output column of struct called 'window'
+#' by default with the nested columns 'start' and 'end'.
#'
-#' @param x a time Column. Must be of TimestampType.
#' @param windowDuration a string specifying the width of the window, e.g. '1 second',
#' '1 day 12 hours', '2 minutes'. Valid interval strings are 'week',
#' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. Note that
@@ -2508,27 +2538,22 @@ setMethod("from_unixtime", signature(x = "Column"),
#' window intervals. For example, in order to have hourly tumbling windows
#' that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide
#' \code{startTime} as \code{"15 minutes"}.
-#' @param ... further arguments to be passed to or from other methods.
-#' @return An output column of struct called 'window' by default with the nested columns 'start'
-#' and 'end'.
-#' @family date time functions
-#' @rdname window
-#' @name window
-#' @aliases window,Column-method
+#' @rdname column_datetime_functions
+#' @aliases window window,Column-method
#' @export
#' @examples
-#'\dontrun{
-#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10,
-#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ...
-#' window(df$time, "1 minute", "15 seconds", "10 seconds")
#'
-#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15,
-#' # 09:01:15-09:02:15...
-#' window(df$time, "1 minute", startTime = "15 seconds")
+#' \dontrun{
+#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10,
+#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ...
+#' window(df$time, "1 minute", "15 seconds", "10 seconds")
#'
-#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ...
-#' window(df$time, "30 seconds", "10 seconds")
-#'}
+#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15,
+#' # 09:01:15-09:02:15...
+#' window(df$time, "1 minute", startTime = "15 seconds")
+#'
+#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ...
+#' window(df$time, "30 seconds", "10 seconds")}
#' @note window since 2.0.0
setMethod("window", signature(x = "Column"),
function(x, windowDuration, slideDuration = NULL, startTime = NULL) {
@@ -2844,27 +2869,16 @@ setMethod("ifelse",
###################### Window functions######################
-#' cume_dist
-#'
-#' Window function: returns the cumulative distribution of values within a window partition,
-#' i.e. the fraction of rows that are below the current row.
-#'
-#' N = total number of rows in the partition
-#' cume_dist(x) = number of values before (and including) x / N
-#'
+#' @details
+#' \code{cume_dist}: Returns the cumulative distribution of values within a window partition,
+#' i.e. the fraction of rows that are below the current row:
+#' (number of values before and including x) / (total number of rows in the partition).
#' This is equivalent to the \code{CUME_DIST} function in SQL.
+#' The method should be used with no argument.
#'
-#' @rdname cume_dist
-#' @name cume_dist
-#' @family window functions
-#' @aliases cume_dist,missing-method
+#' @rdname column_window_functions
+#' @aliases cume_dist cume_dist,missing-method
#' @export
-#' @examples
-#' \dontrun{
-#' df <- createDataFrame(mtcars)
-#' ws <- orderBy(windowPartitionBy("am"), "hp")
-#' out <- select(df, over(cume_dist(), ws), df$hp, df$am)
-#' }
#' @note cume_dist since 1.6.0
setMethod("cume_dist",
signature("missing"),
@@ -2873,28 +2887,19 @@ setMethod("cume_dist",
column(jc)
})
-#' dense_rank
-#'
-#' Window function: returns the rank of rows within a window partition, without any gaps.
+#' @details
+#' \code{dense_rank}: Returns the rank of rows within a window partition, without any gaps.
#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking
#' sequence when there are ties. That is, if you were ranking a competition using dense_rank
#' and had three people tie for second place, you would say that all three were in second
#' place and that the next person came in third. Rank would give me sequential numbers, making
#' the person that came in third place (after the ties) would register as coming in fifth.
-#'
#' This is equivalent to the \code{DENSE_RANK} function in SQL.
+#' The method should be used with no argument.
#'
-#' @rdname dense_rank
-#' @name dense_rank
-#' @family window functions
-#' @aliases dense_rank,missing-method
+#' @rdname column_window_functions
+#' @aliases dense_rank dense_rank,missing-method
#' @export
-#' @examples
-#' \dontrun{
-#' df <- createDataFrame(mtcars)
-#' ws <- orderBy(windowPartitionBy("am"), "hp")
-#' out <- select(df, over(dense_rank(), ws), df$hp, df$am)
-#' }
#' @note dense_rank since 1.6.0
setMethod("dense_rank",
signature("missing"),
@@ -2903,34 +2908,15 @@ setMethod("dense_rank",
column(jc)
})
-#' lag
-#'
-#' Window function: returns the value that is \code{offset} rows before the current row, and
+#' @details
+#' \code{lag}: Returns the value that is \code{offset} rows before the current row, and
#' \code{defaultValue} if there is less than \code{offset} rows before the current row. For example,
#' an \code{offset} of one will return the previous row at any given point in the window partition.
-#'
#' This is equivalent to the \code{LAG} function in SQL.
#'
-#' @param x the column as a character string or a Column to compute on.
-#' @param offset the number of rows back from the current row from which to obtain a value.
-#' If not specified, the default is 1.
-#' @param defaultValue (optional) default to use when the offset row does not exist.
-#' @param ... further arguments to be passed to or from other methods.
-#' @rdname lag
-#' @name lag
-#' @aliases lag,characterOrColumn-method
-#' @family window functions
+#' @rdname column_window_functions
+#' @aliases lag lag,characterOrColumn-method
#' @export
-#' @examples
-#' \dontrun{
-#' df <- createDataFrame(mtcars)
-#'
-#' # Partition by am (transmission) and order by hp (horsepower)
-#' ws <- orderBy(windowPartitionBy("am"), "hp")
-#'
-#' # Lag mpg values by 1 row on the partition-and-ordered table
-#' out <- select(df, over(lag(df$mpg), ws), df$mpg, df$hp, df$am)
-#' }
#' @note lag since 1.6.0
setMethod("lag",
signature(x = "characterOrColumn"),
@@ -2946,35 +2932,16 @@ setMethod("lag",
column(jc)
})
-#' lead
-#'
-#' Window function: returns the value that is \code{offset} rows after the current row, and
+#' @details
+#' \code{lead}: Returns the value that is \code{offset} rows after the current row, and
#' \code{defaultValue} if there is less than \code{offset} rows after the current row.
#' For example, an \code{offset} of one will return the next row at any given point
#' in the window partition.
-#'
#' This is equivalent to the \code{LEAD} function in SQL.
#'
-#' @param x the column as a character string or a Column to compute on.
-#' @param offset the number of rows after the current row from which to obtain a value.
-#' If not specified, the default is 1.
-#' @param defaultValue (optional) default to use when the offset row does not exist.
-#'
-#' @rdname lead
-#' @name lead
-#' @family window functions
-#' @aliases lead,characterOrColumn,numeric-method
+#' @rdname column_window_functions
+#' @aliases lead lead,characterOrColumn,numeric-method
#' @export
-#' @examples
-#' \dontrun{
-#' df <- createDataFrame(mtcars)
-#'
-#' # Partition by am (transmission) and order by hp (horsepower)
-#' ws <- orderBy(windowPartitionBy("am"), "hp")
-#'
-#' # Lead mpg values by 1 row on the partition-and-ordered table
-#' out <- select(df, over(lead(df$mpg), ws), df$mpg, df$hp, df$am)
-#' }
#' @note lead since 1.6.0
setMethod("lead",
signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"),
@@ -2990,31 +2957,15 @@ setMethod("lead",
column(jc)
})
-#' ntile
-#'
-#' Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window
+#' @details
+#' \code{ntile}: Returns the ntile group id (from 1 to n inclusive) in an ordered window
#' partition. For example, if n is 4, the first quarter of the rows will get value 1, the second
#' quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
-#'
#' This is equivalent to the \code{NTILE} function in SQL.
#'
-#' @param x Number of ntile groups
-#'
-#' @rdname ntile
-#' @name ntile
-#' @aliases ntile,numeric-method
-#' @family window functions
+#' @rdname column_window_functions
+#' @aliases ntile ntile,numeric-method
#' @export
-#' @examples
-#' \dontrun{
-#' df <- createDataFrame(mtcars)
-#'
-#' # Partition by am (transmission) and order by hp (horsepower)
-#' ws <- orderBy(windowPartitionBy("am"), "hp")
-#'
-#' # Get ntile group id (1-4) for hp
-#' out <- select(df, over(ntile(4), ws), df$hp, df$am)
-#' }
#' @note ntile since 1.6.0
setMethod("ntile",
signature(x = "numeric"),
@@ -3023,27 +2974,15 @@ setMethod("ntile",
column(jc)
})
-#' percent_rank
-#'
-#' Window function: returns the relative rank (i.e. percentile) of rows within a window partition.
-#'
-#' This is computed by:
-#'
-#' (rank of row in its partition - 1) / (number of rows in the partition - 1)
-#'
-#' This is equivalent to the PERCENT_RANK function in SQL.
+#' @details
+#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window partition.
+#' This is computed by: (rank of row in its partition - 1) / (number of rows in the partition - 1).
+#' This is equivalent to the \code{PERCENT_RANK} function in SQL.
+#' The method should be used with no argument.
#'
-#' @rdname percent_rank
-#' @name percent_rank
-#' @family window functions
-#' @aliases percent_rank,missing-method
+#' @rdname column_window_functions
+#' @aliases percent_rank percent_rank,missing-method
#' @export
-#' @examples
-#' \dontrun{
-#' df <- createDataFrame(mtcars)
-#' ws <- orderBy(windowPartitionBy("am"), "hp")
-#' out <- select(df, over(percent_rank(), ws), df$hp, df$am)
-#' }
#' @note percent_rank since 1.6.0
setMethod("percent_rank",
signature("missing"),
@@ -3052,29 +2991,19 @@ setMethod("percent_rank",
column(jc)
})
-#' rank
-#'
-#' Window function: returns the rank of rows within a window partition.
-#'
+#' @details
+#' \code{rank}: Returns the rank of rows within a window partition.
#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking
#' sequence when there are ties. That is, if you were ranking a competition using dense_rank
#' and had three people tie for second place, you would say that all three were in second
#' place and that the next person came in third. Rank would give me sequential numbers, making
#' the person that came in third place (after the ties) would register as coming in fifth.
+#' This is equivalent to the \code{RANK} function in SQL.
+#' The method should be used with no argument.
#'
-#' This is equivalent to the RANK function in SQL.
-#'
-#' @rdname rank
-#' @name rank
-#' @family window functions
-#' @aliases rank,missing-method
+#' @rdname column_window_functions
+#' @aliases rank rank,missing-method
#' @export
-#' @examples
-#' \dontrun{
-#' df <- createDataFrame(mtcars)
-#' ws <- orderBy(windowPartitionBy("am"), "hp")
-#' out <- select(df, over(rank(), ws), df$hp, df$am)
-#' }
#' @note rank since 1.6.0
setMethod("rank",
signature(x = "missing"),
@@ -3083,11 +3012,7 @@ setMethod("rank",
column(jc)
})
-# Expose rank() in the R base package
-#' @param x a numeric, complex, character or logical vector.
-#' @param ... additional argument(s) passed to the method.
-#' @name rank
-#' @rdname rank
+#' @rdname column_window_functions
#' @aliases rank,ANY-method
#' @export
setMethod("rank",
@@ -3096,23 +3021,14 @@ setMethod("rank",
base::rank(x, ...)
})
-#' row_number
-#'
-#' Window function: returns a sequential number starting at 1 within a window partition.
-#'
-#' This is equivalent to the ROW_NUMBER function in SQL.
+#' @details
+#' \code{row_number}: Returns a sequential number starting at 1 within a window partition.
+#' This is equivalent to the \code{ROW_NUMBER} function in SQL.
+#' The method should be used with no argument.
#'
-#' @rdname row_number
-#' @name row_number
-#' @aliases row_number,missing-method
-#' @family window functions
+#' @rdname column_window_functions
+#' @aliases row_number row_number,missing-method
#' @export
-#' @examples
-#' \dontrun{
-#' df <- createDataFrame(mtcars)
-#' ws <- orderBy(windowPartitionBy("am"), "hp")
-#' out <- select(df, over(row_number(), ws), df$hp, df$am)
-#' }
#' @note row_number since 1.6.0
setMethod("row_number",
signature("missing"),
@@ -3127,7 +3043,7 @@ setMethod("row_number",
#' \code{array_contains}: Returns null if the array is null, true if the array contains
#' the value, and false otherwise.
#'
-#' @param value A value to be checked if contained in the column
+#' @param value a value to be checked if contained in the column
#' @rdname column_collection_functions
#' @aliases array_contains array_contains,Column-method
#' @export
@@ -3172,7 +3088,7 @@ setMethod("size",
#' to the natural ordering of the array elements.
#'
#' @rdname column_collection_functions
-#' @param asc A logical flag indicating the sorting order.
+#' @param asc a logical flag indicating the sorting order.
#' TRUE, sorting is in ascending order.
#' FALSE, sorting is in descending order.
#' @aliases sort_array sort_array,Column-method
@@ -3299,7 +3215,7 @@ setMethod("split_string",
#' \code{repeat_string}: Repeats string n times.
#' Equivalent to \code{repeat} SQL function.
#'
-#' @param n Number of repetitions
+#' @param n number of repetitions.
#' @rdname column_string_functions
#' @aliases repeat_string repeat_string,Column-method
#' @export
@@ -3428,7 +3344,7 @@ setMethod("grouping_bit",
#' \code{grouping_id}: Returns the level of grouping.
#' Equals to \code{
#' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn)
-#' }
+#' }.
#'
#' @rdname column_aggregate_functions
#' @aliases grouping_id grouping_id,Column-method
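
A short sketch of the from_json change above: the schema can now be passed as a DDL-formatted string instead of a structType object. This mirrors the roxygen example added in this file and assumes an active SparkR session.

    df2 <- sql("SELECT named_struct('name', 'Bob') as people")
    df2 <- mutate(df2, people_json = to_json(df2$people))
    # structType(structField("name", "string")) and the DDL string "name STRING"
    # are now interchangeable here.
    head(select(df2, from_json(df2$people_json, "name STRING")))
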
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index b901b74e4728..92098741f72f 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1013,9 +1013,9 @@ setGeneric("create_map", function(x, ...) { standardGeneric("create_map") })
#' @name NULL
setGeneric("hash", function(x, ...) { standardGeneric("hash") })
-#' @param x empty. Should be used with no argument.
-#' @rdname cume_dist
+#' @rdname column_window_functions
#' @export
+#' @name NULL
setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") })
#' @rdname column_datetime_diff_functions
@@ -1053,9 +1053,9 @@ setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") })
#' @name NULL
setGeneric("decode", function(x, charset) { standardGeneric("decode") })
-#' @param x empty. Should be used with no argument.
-#' @rdname dense_rank
+#' @rdname column_window_functions
#' @export
+#' @name NULL
setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") })
#' @rdname column_string_functions
@@ -1159,8 +1159,9 @@ setGeneric("isnan", function(x) { standardGeneric("isnan") })
#' @name NULL
setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") })
-#' @rdname lag
+#' @rdname column_window_functions
#' @export
+#' @name NULL
setGeneric("lag", function(x, ...) { standardGeneric("lag") })
#' @rdname last
@@ -1172,8 +1173,9 @@ setGeneric("last", function(x, ...) { standardGeneric("last") })
#' @name NULL
setGeneric("last_day", function(x) { standardGeneric("last_day") })
-#' @rdname lead
+#' @rdname column_window_functions
#' @export
+#' @name NULL
setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") })
#' @rdname column_nonaggregate_functions
@@ -1260,8 +1262,9 @@ setGeneric("not", function(x) { standardGeneric("not") })
#' @name NULL
setGeneric("next_day", function(y, x) { standardGeneric("next_day") })
-#' @rdname ntile
+#' @rdname column_window_functions
#' @export
+#' @name NULL
setGeneric("ntile", function(x) { standardGeneric("ntile") })
#' @rdname column_aggregate_functions
@@ -1269,9 +1272,9 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") })
#' @name NULL
setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
-#' @param x empty. Should be used with no argument.
-#' @rdname percent_rank
+#' @rdname column_window_functions
#' @export
+#' @name NULL
setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") })
#' @rdname column_math_functions
@@ -1304,8 +1307,9 @@ setGeneric("rand", function(seed) { standardGeneric("rand") })
#' @name NULL
setGeneric("randn", function(seed) { standardGeneric("randn") })
-#' @rdname rank
+#' @rdname column_window_functions
#' @export
+#' @name NULL
setGeneric("rank", function(x, ...) { standardGeneric("rank") })
#' @rdname column_string_functions
@@ -1334,9 +1338,9 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") })
#' @name NULL
setGeneric("rint", function(x) { standardGeneric("rint") })
-#' @param x empty. Should be used with no argument.
-#' @rdname row_number
+#' @rdname column_window_functions
#' @export
+#' @name NULL
setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") })
#' @rdname column_string_functions
@@ -1414,9 +1418,9 @@ setGeneric("split_string", function(x, pattern) { standardGeneric("split_string"
#' @name NULL
setGeneric("soundex", function(x) { standardGeneric("soundex") })
-#' @param x empty. Should be used with no argument.
-#' @rdname spark_partition_id
+#' @rdname column_nonaggregate_functions
#' @export
+#' @name NULL
setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") })
#' @rdname column_aggregate_functions
@@ -1534,8 +1538,9 @@ setGeneric("var_samp", function(x) { standardGeneric("var_samp") })
#' @name NULL
setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") })
-#' @rdname window
+#' @rdname column_datetime_functions
#' @export
+#' @name NULL
setGeneric("window", function(x, ...) { standardGeneric("window") })
#' @rdname column_datetime_functions
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 17f5283abead..0a7be0e99397 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -233,6 +233,9 @@ setMethod("gapplyCollect",
})
gapplyInternal <- function(x, func, schema) {
+ if (is.character(schema)) {
+ schema <- structType(schema)
+ }
packageNamesArr <- serialize(.sparkREnv[[".packages"]],
connection = NULL)
broadcastArr <- lapply(ls(.broadcastNames),
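
Because gapplyInternal now converts a character schema with structType(), gapply accepts a DDL-formatted string as well; a sketch under the same assumptions (active session, illustrative data):

    df <- createDataFrame(data.frame(a = c(1L, 1L, 3L), b = c(1, 2, 3), c = c("1", "1", "3"),
                                     stringsAsFactors = FALSE))
    result <- gapply(
      df,
      c("a", "c"),
      function(key, x) {
        data.frame(key, mean(x$b), stringsAsFactors = FALSE)
      },
      "a INT, c STRING, avg DOUBLE")
    collect(result)
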
diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 2f1220a75278..75b1a74ee8c7 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#' can speed up training of deeper trees. Users can set how often should the
#' cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in a classification model.
+#' Supported options: "skip" (filter out rows with invalid data),
+#' "error" (throw an error), "keep" (put invalid data in a special additional
+#' bucket, at index numLabels). Default is "error".
#' @param ... additional arguments passed to the method.
#' @aliases spark.randomForest,SparkDataFrame,formula-method
#' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
- maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+ maxMemoryInMB = 256, cacheNodeIds = FALSE,
+ handleInvalid = c("error", "keep", "skip")) {
type <- match.arg(type)
formula <- paste(deparse(formula), collapse = "")
if (!is.null(seed)) {
@@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
new("RandomForestRegressionModel", jobj = jobj)
},
classification = {
+ handleInvalid <- match.arg(handleInvalid)
if (is.null(impurity)) impurity <- "gini"
impurity <- match.arg(impurity, c("gini", "entropy"))
jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
@@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
as.numeric(minInfoGain), as.integer(checkpointInterval),
as.character(featureSubsetStrategy), seed,
as.numeric(subsamplingRate),
- as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+ as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+ handleInvalid)
new("RandomForestClassificationModel", jobj = jobj)
}
)
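
A sketch of the new handleInvalid argument for spark.randomForest classification, mirroring the added test: without it, predicting rows whose string features contain unseen levels fails; with "skip" such rows are filtered out. The data values are illustrative and an active SparkR session is assumed.

    data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
                       someString = base::sample(c("this", "that"), 10, replace = TRUE),
                       stringsAsFactors = FALSE)
    traindf <- as.DataFrame(data)
    # Append a row whose string feature level ("the other") was never seen in training.
    testdf <- as.DataFrame(rbind(data, c(0, "the other")))
    model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
                                maxDepth = 10, maxBins = 10, numTrees = 10,
                                handleInvalid = "skip")
    head(collect(predict(model, testdf)))
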
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
index cb5bdb90175b..d1ed6833d5d0 100644
--- a/R/pkg/R/schema.R
+++ b/R/pkg/R/schema.R
@@ -23,18 +23,24 @@
#' Create a structType object that contains the metadata for a SparkDataFrame. Intended for
#' use with createDataFrame and toDF.
#'
-#' @param x a structField object (created with the field() function)
+#' @param x a structField object (created with the \code{structField} method). Since Spark 2.3,
+#' this can be a DDL-formatted string, which is a comma-separated list of field
+#' definitions, e.g., "a INT, b STRING".
#' @param ... additional structField objects
#' @return a structType object
#' @rdname structType
#' @export
#' @examples
#'\dontrun{
-#' schema <- structType(structField("a", "integer"), structField("c", "string"),
+#' schema <- structType(structField("a", "integer"), structField("c", "string"),
#' structField("avg", "double"))
#' df1 <- gapply(df, list("a", "c"),
#' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) },
#' schema)
+#' schema <- structType("a INT, c STRING, avg DOUBLE")
+#' df1 <- gapply(df, list("a", "c"),
+#' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) },
+#' schema)
#' }
#' @note structType since 1.4.0
structType <- function(x, ...) {
@@ -68,6 +74,23 @@ structType.structField <- function(x, ...) {
structType(stObj)
}
+#' @rdname structType
+#' @method structType character
+#' @export
+structType.character <- function(x, ...) {
+ if (!is.character(x)) {
+ stop("schema must be a DDL-formatted string.")
+ }
+ if (length(list(...)) > 0) {
+ stop("multiple DDL-formatted strings are not supported")
+ }
+
+ stObj <- handledCallJStatic("org.apache.spark.sql.types.StructType",
+ "fromDDL",
+ x)
+ structType(stObj)
+}
+
#' Print a Spark StructType.
#'
#' This function prints the contents of a StructType returned from the
@@ -102,7 +125,7 @@ print.structType <- function(x, ...) {
#' field1 <- structField("a", "integer")
#' field2 <- structField("c", "string")
#' field3 <- structField("avg", "double")
-#' schema <- structType(field1, field2, field3)
+#' schema <- structType(field1, field2, field3)
#' df1 <- gapply(df, list("a", "c"),
#' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) },
#' schema)
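
For reference, a brief sketch of the new structType.character path: the DDL string is parsed on the JVM side via StructType.fromDDL, so the two constructions below are interchangeable (active session assumed).

    schema1 <- structType(structField("a", "string"), structField("b", "integer"))
    schema2 <- structType("a STRING, b INT")
    # Both yield a structType with fields a (string) and b (integer).
    print(schema1)
    print(schema2)
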
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index f2d2620e5447..81507ea7186a 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -113,7 +113,7 @@ sparkR.stop <- function() {
#' list(spark.executor.memory="4g"),
#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"),
#' c("one.jar", "two.jar", "three.jar"),
-#' c("com.databricks:spark-avro_2.10:2.0.1"))
+#' c("com.databricks:spark-avro_2.11:2.0.1"))
#'}
#' @note sparkR.init since 1.4.0
sparkR.init <- function(
@@ -357,7 +357,7 @@ sparkRHive.init <- function(jsc = NULL) {
#' sparkR.session("yarn-client", "SparkR", "/home/spark",
#' list(spark.executor.memory="4g"),
#' c("one.jar", "two.jar", "three.jar"),
-#' c("com.databricks:spark-avro_2.10:2.0.1"))
+#' c("com.databricks:spark-avro_2.11:2.0.1"))
#' sparkR.session(spark.master = "yarn-client", spark.executor.memory = "4g")
#'}
#' @note sparkR.session since 2.0.0
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
index 3a318b71ea06..2e31dc5f728c 100644
--- a/R/pkg/inst/worker/daemon.R
+++ b/R/pkg/inst/worker/daemon.R
@@ -30,8 +30,50 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
+# Waits indefinitely for a socket connection by default.
+selectTimeout <- NULL
+
while (TRUE) {
- ready <- socketSelect(list(inputCon))
+ ready <- socketSelect(list(inputCon), timeout = selectTimeout)
+
+ # Note that the children should be terminated in the parent. If each child terminates
+ # itself, it appears that the resource is not released properly, which causes an unexpected
+ # termination of this daemon due to, for example, running out of file descriptors
+ # (see SPARK-21093). Therefore, the current implementation tries to retrieve children
+ # that have exited (but are not yet terminated) and then sends a kill signal to terminate them properly
+ # in the parent.
+ #
+ # There are two paths by which the parent attempts to send a signal to terminate the children.
+ #
+ # 1. Every second if any socket connection is not available and if there are child workers
+ # running.
+ # 2. Right after a socket connection is available.
+ #
+ # In other words, the parent attempts to send the signal to the children every second if
+ # any worker is running, or right before launching new worker children for the next
+ # socket connection.
+
+ # The process IDs of exited children are returned below.
+ children <- parallel:::selectChildren(timeout = 0)
+
+ if (is.integer(children)) {
+ lapply(children, function(child) {
+ # This should be the PID of the exited child. Otherwise, this returns raw bytes if any data
+ # was sent from this child, in which case we discard it.
+ pid <- parallel:::readChild(child)
+ if (is.integer(pid)) {
+ # This checks if the data from this child is the same PID as this selected child.
+ if (child == pid) {
+ # If so, we terminate this child.
+ tools::pskill(child, tools::SIGUSR1)
+ }
+ }
+ })
+ } else if (is.null(children)) {
+ # If it is NULL, there are no children. Waits indefinitely for a socket connection.
+ selectTimeout <- NULL
+ }
+
if (ready) {
port <- SparkR:::readInt(inputCon)
# There is a small chance that it could be interrupted by signal, retry one time
@@ -44,12 +86,15 @@ while (TRUE) {
}
p <- parallel:::mcfork()
if (inherits(p, "masterProcess")) {
+ # Reach here because this is a child process.
close(inputCon)
Sys.setenv(SPARKR_WORKER_PORT = port)
try(source(script))
- # Set SIGUSR1 so that child can exit
- tools::pskill(Sys.getpid(), tools::SIGUSR1)
+ # Note that this mcexit does not fully terminate this child.
parallel:::mcexit(0L)
+ } else {
+ # Forking succeeded, so check every second whether the children have finished their jobs.
+ selectTimeout <- 1
}
}
}
diff --git a/R/pkg/tests/fulltests/test_client.R b/R/pkg/tests/fulltests/test_client.R
index 0cf25fe1dbf3..de624b572cc2 100644
--- a/R/pkg/tests/fulltests/test_client.R
+++ b/R/pkg/tests/fulltests/test_client.R
@@ -37,7 +37,7 @@ test_that("multiple packages don't produce a warning", {
test_that("sparkJars sparkPackages as character vectors", {
args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "",
- c("com.databricks:spark-avro_2.10:2.0.1"))
+ c("com.databricks:spark-avro_2.11:2.0.1"))
expect_match(args, "--jars one.jar,two.jar,three.jar")
- expect_match(args, "--packages com.databricks:spark-avro_2.10:2.0.1")
+ expect_match(args, "--packages com.databricks:spark-avro_2.11:2.0.1")
})
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index 9b3fc8d270b2..66a0693a59a5 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -212,6 +212,23 @@ test_that("spark.randomForest", {
expect_equal(length(grep("1.0", predictions)), 50)
expect_equal(length(grep("2.0", predictions)), 50)
+ # Test unseen labels
+ data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+ someString = base::sample(c("this", "that"), 10, replace = TRUE),
+ stringsAsFactors = FALSE)
+ trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+ traindf <- as.DataFrame(data[trainidxs, ])
+ testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+ model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+ maxDepth = 10, maxBins = 10, numTrees = 10)
+ predictions <- predict(model, testdf)
+ expect_error(collect(predictions))
+ model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+ maxDepth = 10, maxBins = 10, numTrees = 10,
+ handleInvalid = "skip")
+ predictions <- predict(model, testdf)
+ expect_equal(class(collect(predictions)$clicked[1]), "character")
+
# spark.randomForest classification can work on libsvm data
if (windows_with_hadoop()) {
data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index a2bcb5aefe16..77052d4a2834 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -146,6 +146,13 @@ test_that("structType and structField", {
expect_is(testSchema, "structType")
expect_is(testSchema$fields()[[2]], "structField")
expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType")
+
+ testSchema <- structType("a STRING, b INT")
+ expect_is(testSchema, "structType")
+ expect_is(testSchema$fields()[[2]], "structField")
+ expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType")
+
+ expect_error(structType("A stri"), "DataType stri is not supported.")
})
test_that("structField type strings", {
@@ -1480,13 +1487,15 @@ test_that("column functions", {
j <- collect(select(df, alias(to_json(df$info), "json")))
expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}")
df <- as.DataFrame(j)
- schema <- structType(structField("age", "integer"),
- structField("height", "double"))
- s <- collect(select(df, alias(from_json(df$json, schema), "structcol")))
- expect_equal(ncol(s), 1)
- expect_equal(nrow(s), 3)
- expect_is(s[[1]][[1]], "struct")
- expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } )))
+ schemas <- list(structType(structField("age", "integer"), structField("height", "double")),
+ "age INT, height DOUBLE")
+ for (schema in schemas) {
+ s <- collect(select(df, alias(from_json(df$json, schema), "structcol")))
+ expect_equal(ncol(s), 1)
+ expect_equal(nrow(s), 3)
+ expect_is(s[[1]][[1]], "struct")
+ expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } )))
+ }
# passing option
df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}")))
@@ -1504,14 +1513,15 @@ test_that("column functions", {
# check if array type in string is correctly supported.
jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]"
df <- as.DataFrame(list(list("people" = jsonArr)))
- schema <- structType(structField("name", "string"))
- arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol")))
- expect_equal(ncol(arr), 1)
- expect_equal(nrow(arr), 1)
- expect_is(arr[[1]][[1]], "list")
- expect_equal(length(arr$arrcol[[1]]), 2)
- expect_equal(arr$arrcol[[1]][[1]]$name, "Bob")
- expect_equal(arr$arrcol[[1]][[2]]$name, "Alice")
+ for (schema in list(structType(structField("name", "string")), "name STRING")) {
+ arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol")))
+ expect_equal(ncol(arr), 1)
+ expect_equal(nrow(arr), 1)
+ expect_is(arr[[1]][[1]], "list")
+ expect_equal(length(arr$arrcol[[1]]), 2)
+ expect_equal(arr$arrcol[[1]][[1]]$name, "Bob")
+ expect_equal(arr$arrcol[[1]][[2]]$name, "Alice")
+ }
# Test create_array() and create_map()
df <- as.DataFrame(data.frame(
@@ -2885,30 +2895,33 @@ test_that("dapply() and dapplyCollect() on a DataFrame", {
expect_identical(ldf, result)
# Filter and add a column
- schema <- structType(structField("a", "integer"), structField("b", "double"),
- structField("c", "string"), structField("d", "integer"))
- df1 <- dapply(
- df,
- function(x) {
- y <- x[x$a > 1, ]
- y <- cbind(y, y$a + 1L)
- },
- schema)
- result <- collect(df1)
- expected <- ldf[ldf$a > 1, ]
- expected$d <- expected$a + 1L
- rownames(expected) <- NULL
- expect_identical(expected, result)
-
- result <- dapplyCollect(
- df,
- function(x) {
- y <- x[x$a > 1, ]
- y <- cbind(y, y$a + 1L)
- })
- expected1 <- expected
- names(expected1) <- names(result)
- expect_identical(expected1, result)
+ schemas <- list(structType(structField("a", "integer"), structField("b", "double"),
+ structField("c", "string"), structField("d", "integer")),
+ "a INT, b DOUBLE, c STRING, d INT")
+ for (schema in schemas) {
+ df1 <- dapply(
+ df,
+ function(x) {
+ y <- x[x$a > 1, ]
+ y <- cbind(y, y$a + 1L)
+ },
+ schema)
+ result <- collect(df1)
+ expected <- ldf[ldf$a > 1, ]
+ expected$d <- expected$a + 1L
+ rownames(expected) <- NULL
+ expect_identical(expected, result)
+
+ result <- dapplyCollect(
+ df,
+ function(x) {
+ y <- x[x$a > 1, ]
+ y <- cbind(y, y$a + 1L)
+ })
+ expected1 <- expected
+ names(expected1) <- names(result)
+ expect_identical(expected1, result)
+ }
# Remove the added column
df2 <- dapply(
@@ -3020,29 +3033,32 @@ test_that("gapply() and gapplyCollect() on a DataFrame", {
# Computes the sum of second column by grouping on the first and third columns
# and checks if the sum is larger than 2
- schema <- structType(structField("a", "integer"), structField("e", "boolean"))
- df2 <- gapply(
- df,
- c(df$"a", df$"c"),
- function(key, x) {
- y <- data.frame(key[1], sum(x$b) > 2)
- },
- schema)
- actual <- collect(df2)$e
- expected <- c(TRUE, TRUE)
- expect_identical(actual, expected)
-
- df2Collect <- gapplyCollect(
- df,
- c(df$"a", df$"c"),
- function(key, x) {
- y <- data.frame(key[1], sum(x$b) > 2)
- colnames(y) <- c("a", "e")
- y
- })
- actual <- df2Collect$e
+ schemas <- list(structType(structField("a", "integer"), structField("e", "boolean")),
+ "a INT, e BOOLEAN")
+ for (schema in schemas) {
+ df2 <- gapply(
+ df,
+ c(df$"a", df$"c"),
+ function(key, x) {
+ y <- data.frame(key[1], sum(x$b) > 2)
+ },
+ schema)
+ actual <- collect(df2)$e
+ expected <- c(TRUE, TRUE)
expect_identical(actual, expected)
+ df2Collect <- gapplyCollect(
+ df,
+ c(df$"a", df$"c"),
+ function(key, x) {
+ y <- data.frame(key[1], sum(x$b) > 2)
+ colnames(y) <- c("a", "e")
+ y
+ })
+ actual <- df2Collect$e
+ expect_identical(actual, expected)
+ }
+
# Computes the arithmetic mean of the second column by grouping
# on the first and third columns. Output the groupping value and the average.
schema <- structType(structField("a", "integer"), structField("c", "string"),
diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd
index 0977025c2036..993aa31a4c37 100644
--- a/bin/load-spark-env.cmd
+++ b/bin/load-spark-env.cmd
@@ -35,21 +35,21 @@ if [%SPARK_ENV_LOADED%] == [] (
rem Setting SPARK_SCALA_VERSION if not already set.
-set ASSEMBLY_DIR2="%SPARK_HOME%\assembly\target\scala-2.11"
-set ASSEMBLY_DIR1="%SPARK_HOME%\assembly\target\scala-2.10"
+rem set ASSEMBLY_DIR2="%SPARK_HOME%\assembly\target\scala-2.11"
+rem set ASSEMBLY_DIR1="%SPARK_HOME%\assembly\target\scala-2.12"
if [%SPARK_SCALA_VERSION%] == [] (
- if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% (
- echo "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected."
- echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd."
- exit 1
- )
- if exist %ASSEMBLY_DIR2% (
+ rem if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% (
+ rem echo "Presence of build for multiple Scala versions detected."
+ rem echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd."
+ rem exit 1
+ rem )
+ rem if exist %ASSEMBLY_DIR2% (
set SPARK_SCALA_VERSION=2.11
- ) else (
- set SPARK_SCALA_VERSION=2.10
- )
+ rem ) else (
+ rem set SPARK_SCALA_VERSION=2.12
+ rem )
)
exit /b 0
diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh
index 8a2f709960a2..9de62039c51e 100644
--- a/bin/load-spark-env.sh
+++ b/bin/load-spark-env.sh
@@ -46,18 +46,18 @@ fi
if [ -z "$SPARK_SCALA_VERSION" ]; then
- ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11"
- ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.10"
+ #ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11"
+ #ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.12"
- if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then
- echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2
- echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2
- exit 1
- fi
+ #if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then
+ # echo -e "Presence of build for multiple Scala versions detected." 1>&2
+ # echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2
+ # exit 1
+ #fi
- if [ -d "$ASSEMBLY_DIR2" ]; then
+ #if [ -d "$ASSEMBLY_DIR2" ]; then
export SPARK_SCALA_VERSION="2.11"
- else
- export SPARK_SCALA_VERSION="2.10"
- fi
+ #else
+ # export SPARK_SCALA_VERSION="2.12"
+ #fi
fi
diff --git a/bin/pyspark b/bin/pyspark
index 98387c2ec5b8..dd286277c1fc 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -57,7 +57,7 @@ export PYSPARK_PYTHON
# Add the PySpark classes to the Python path:
export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
-export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH"
+export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH"
# Load the PySpark shell.py script when ./pyspark is used interactively:
export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
@@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
export PYTHONHASHSEED=0
- exec "$PYSPARK_DRIVER_PYTHON" -m "$1"
+ exec "$PYSPARK_DRIVER_PYTHON" -m "$@"
exit
fi
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index f211c0873ad2..46d4d5c883cf 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
)
set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
-set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH%
+set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH%
set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py
diff --git a/build/mvn b/build/mvn
index 20f78869a6ce..9beb8af2878a 100755
--- a/build/mvn
+++ b/build/mvn
@@ -87,13 +87,13 @@ install_mvn() {
# Install zinc under the build/ folder
install_zinc() {
- local zinc_path="zinc-0.3.11/bin/zinc"
+ local zinc_path="zinc-0.3.15/bin/zinc"
[ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1
local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com}
install_app \
- "${TYPESAFE_MIRROR}/zinc/0.3.11" \
- "zinc-0.3.11.tgz" \
+ "${TYPESAFE_MIRROR}/zinc/0.3.15" \
+ "zinc-0.3.15.tgz" \
"${zinc_path}"
ZINC_BIN="${_DIR}/${zinc_path}"
}
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index 066970f24205..0254d0cefc36 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -90,7 +90,8 @@
org.apache.spark
spark-tags_${scala.binary.version}
-
+ test
+
+ runtime
log4j
@@ -2115,9 +2116,9 @@
${antlr4.version}
-        <groupId>${jline.groupid}</groupId>
+        <groupId>jline</groupId>
         <artifactId>jline</artifactId>
-        <version>${jline.version}</version>
+        <version>2.12.1</version>
org.apache.commons
@@ -2135,6 +2136,25 @@
paranamer
${paranamer.version}
+      <dependency>
+        <groupId>org.apache.arrow</groupId>
+        <artifactId>arrow-vector</artifactId>
+        <version>${arrow.version}</version>
+        <exclusions>
+          <exclusion>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-annotations</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-databind</artifactId>
+          </exclusion>
+          <exclusion>
+            <groupId>io.netty</groupId>
+            <artifactId>netty-handler</artifactId>
+          </exclusion>
+        </exclusions>
+      </dependency>
@@ -2170,6 +2190,7 @@
-->
org.jboss.netty
org.codehaus.groovy
+ *:*_2.10
true
@@ -2224,6 +2245,8 @@
-unchecked
-deprecation
-feature
+ -explaintypes
+ -Yno-adapted-args
-Xms1024m
@@ -2597,6 +2620,7 @@
org.eclipse.jetty:jetty-util
org.eclipse.jetty:jetty-server
com.google.guava:guava
+ org.jpmml:*
@@ -2611,6 +2635,14 @@
com.google.common
${spark.shade.packageName}.guava
+
+ org.dmg.pmml
+ ${spark.shade.packageName}.dmg.pmml
+
+
+ org.jpmml
+ ${spark.shade.packageName}.jpmml
+
@@ -2912,44 +2944,6 @@
-
- scala-2.10
-
- scala-2.10
-
-
- 2.10.6
- 2.10
- ${scala.version}
- org.scala-lang
-
-
-
-
- org.apache.maven.plugins
- maven-enforcer-plugin
-
-
- enforce-versions
-
- enforce
-
-
-
-
-
- *:*_2.11
-
-
-
-
-
-
-
-
-
-
-
test-java-home
@@ -2960,16 +2954,18 @@
+
scala-2.11
-
- !scala-2.10
-
+
+
+
+
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 4cc4a7fc2eb7..ccbd1e92d5cb 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -92,19 +92,11 @@ object SparkBuild extends PomBuild {
val projectsMap: mutable.Map[String, Seq[Setting[_]]] = mutable.Map.empty
override val profiles = {
- val profiles = Properties.propOrNone("sbt.maven.profiles") orElse Properties.envOrNone("SBT_MAVEN_PROFILES") match {
+ Properties.envOrNone("SBT_MAVEN_PROFILES") match {
case None => Seq("sbt")
case Some(v) =>
v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq
}
-
- if (System.getProperty("scala-2.10") == "") {
- // To activate scala-2.10 profile, replace empty property value to non-empty value
- // in the same way as Maven which handles -Dname as -Dname=true before executes build process.
- // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082
- System.setProperty("scala-2.10", "true")
- }
- profiles
}
Properties.envOrNone("SBT_MAVEN_PROPERTIES") match {
@@ -240,9 +232,7 @@ object SparkBuild extends PomBuild {
},
javacJVMVersion := "1.8",
- // SBT Scala 2.10 build still doesn't support Java 8, because scalac 2.10 doesn't, but,
- // it also doesn't touch Java 8 code and it's OK to emit Java 7 bytecode in this case
- scalacJVMVersion := (if (System.getProperty("scala-2.10") == "true") "1.7" else "1.8"),
+ scalacJVMVersion := "1.8",
javacOptions in Compile ++= Seq(
"-encoding", "UTF-8",
@@ -492,7 +482,6 @@ object OldDeps {
def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq(
name := "old-deps",
- scalaVersion := "2.10.5",
libraryDependencies := allPreviousArtifactKeys.value.flatten
)
}
@@ -772,13 +761,7 @@ object CopyDependencies {
object TestSettings {
import BuildCommons._
- private val scalaBinaryVersion =
- if (System.getProperty("scala-2.10") == "true") {
- "2.10"
- } else {
- "2.11"
- }
-
+ private val scalaBinaryVersion = "2.11"
lazy val settings = Seq (
// Fork new JVMs for tests and set Java options for those
fork := true,
diff --git a/python/README.md b/python/README.md
index 0a5c8010b848..84ec88141cb0 100644
--- a/python/README.md
+++ b/python/README.md
@@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c
## Python Requirements
-At its core PySpark depends on Py4J (currently version 0.10.4), but additional sub-packages have their own requirements (including numpy and pandas).
\ No newline at end of file
+At its core PySpark depends on Py4J (currently version 0.10.6), but additional sub-packages have their own requirements (including numpy and pandas).
diff --git a/python/docs/Makefile b/python/docs/Makefile
index 5e4cfb8ab6fe..09898f29950e 100644
--- a/python/docs/Makefile
+++ b/python/docs/Makefile
@@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build
PAPER ?=
BUILDDIR ?= _build
-export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.4-src.zip)
+export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip)
# User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
diff --git a/python/lib/py4j-0.10.4-src.zip b/python/lib/py4j-0.10.4-src.zip
deleted file mode 100644
index 8c3829e32872..000000000000
Binary files a/python/lib/py4j-0.10.4-src.zip and /dev/null differ
diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip
new file mode 100644
index 000000000000..2f8edcc0c0b8
Binary files /dev/null and b/python/lib/py4j-0.10.6-src.zip differ
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 9b345ac73f3d..948806a5c936 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1265,8 +1265,8 @@ def theta(self):
@inherit_doc
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- HasMaxIter, HasTol, HasSeed, HasStepSize, JavaMLWritable,
- JavaMLReadable):
+ HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver,
+ JavaMLWritable, JavaMLReadable):
"""
Classifier trainer based on the Multilayer Perceptron.
Each layer has sigmoid activation function, output layer has softmax.
@@ -1407,20 +1407,6 @@ def getStepSize(self):
"""
return self.getOrDefault(self.stepSize)
- @since("2.0.0")
- def setSolver(self, value):
- """
- Sets the value of :py:attr:`solver`.
- """
- return self._set(solver=value)
-
- @since("2.0.0")
- def getSolver(self):
- """
- Gets the value of solver or its default value.
- """
- return self.getOrDefault(self.solver)
-
@since("2.0.0")
def setInitialWeights(self, value):
"""
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 77de1cc18246..7eb1b9fac2f5 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -314,7 +314,8 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable)
@inherit_doc
-class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
+ JavaMLReadable, JavaMLWritable):
"""
Maps a column of continuous features to a column of feature buckets.
@@ -398,20 +399,6 @@ def getSplits(self):
"""
return self.getOrDefault(self.splits)
- @since("2.1.0")
- def setHandleInvalid(self, value):
- """
- Sets the value of :py:attr:`handleInvalid`.
- """
- return self._set(handleInvalid=value)
-
- @since("2.1.0")
- def getHandleInvalid(self):
- """
- Gets the value of :py:attr:`handleInvalid` or its default value.
- """
- return self.getOrDefault(self.handleInvalid)
-
@inherit_doc
class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
@@ -1623,7 +1610,8 @@ def getDegree(self):
@inherit_doc
-class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
+ JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1743,20 +1731,6 @@ def getRelativeError(self):
"""
return self.getOrDefault(self.relativeError)
- @since("2.1.0")
- def setHandleInvalid(self, value):
- """
- Sets the value of :py:attr:`handleInvalid`.
- """
- return self._set(handleInvalid=value)
-
- @since("2.1.0")
- def getHandleInvalid(self):
- """
- Gets the value of :py:attr:`handleInvalid` or its default value.
- """
- return self.getOrDefault(self.handleInvalid)
-
def _create_model(self, java_model):
"""
Private method to convert the java_model to a Python model.
@@ -2132,6 +2106,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
"frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
typeConverter=TypeConverters.toString)
+ handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
+ "labels or NULL values). Options are 'skip' (filter out rows with " +
+ "invalid data), error (throw an error), or 'keep' (put invalid data " +
+ "in a special additional bucket, at index numLabels).",
+ typeConverter=TypeConverters.toString)
+
@keyword_only
def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
stringOrderType="frequencyDesc"):
@@ -2971,7 +2951,8 @@ def explainedVariance(self):
@inherit_doc
-class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable):
+class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, HasHandleInvalid,
+ JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -3014,6 +2995,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
True
>>> loadedRF.getLabelCol() == rf.getLabelCol()
True
+ >>> loadedRF.getHandleInvalid() == rf.getHandleInvalid()
+ True
>>> str(loadedRF)
'RFormula(y ~ x + s) (uid=...)'
>>> modelPath = temp_path + "/rFormulaModel"
@@ -3052,26 +3035,37 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
"RFormula drops the same category as R when encoding strings.",
typeConverter=TypeConverters.toString)
+ handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
+ "Options are 'skip' (filter out rows with invalid values), " +
+ "'error' (throw an error), or 'keep' (put invalid data in a special " +
+ "additional bucket, at index numLabels).",
+ typeConverter=TypeConverters.toString)
+
@keyword_only
def __init__(self, formula=None, featuresCol="features", labelCol="label",
- forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
+ forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
+ handleInvalid="error"):
"""
__init__(self, formula=None, featuresCol="features", labelCol="label", \
- forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
+ forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
+ handleInvalid="error")
"""
super(RFormula, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
- self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
+ self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
+ handleInvalid="error")
kwargs = self._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.5.0")
def setParams(self, formula=None, featuresCol="features", labelCol="label",
- forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
+ forceIndexLabel=False, stringIndexerOrderType="frequencyDesc",
+ handleInvalid="error"):
"""
setParams(self, formula=None, featuresCol="features", labelCol="label", \
- forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
+ forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \
+ handleInvalid="error")
Sets params for RFormula.
"""
kwargs = self._input_kwargs
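
A minimal sketch (assuming a local Spark install with this patch applied) of the handleInvalid behaviour that the test added to python/pyspark/ml/tests.py below also exercises: "keep" routes unseen or NULL labels to an extra bucket instead of failing the job.

    from pyspark.sql import SparkSession
    from pyspark.ml.feature import StringIndexer

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    df = spark.createDataFrame([(0, "a"), (1, "b"), (2, None)], ["id", "label"])
    indexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep")
    indexer.fit(df).transform(df).show()   # the None row lands in the extra bucket

    spark.stop()
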
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 2d17f95b0c44..f0ff7a5f59ab 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -95,6 +95,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
.. versionadded:: 1.4.0
"""
+ solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
+ "options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString)
+
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
@@ -1371,17 +1374,22 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
"Only applicable to the Tweedie family.",
typeConverter=TypeConverters.toFloat)
+ solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
+ "options: irls.", typeConverter=TypeConverters.toString)
+ offsetCol = Param(Params._dummy(), "offsetCol", "The offset column name. If this is not set " +
+ "or empty, we treat all instance offsets as 0.0",
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
- variancePower=0.0, linkPower=None):
+ variancePower=0.0, linkPower=None, offsetCol=None):
"""
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
- variancePower=0.0, linkPower=None)
+ variancePower=0.0, linkPower=None, offsetCol=None)
"""
super(GeneralizedLinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
@@ -1397,12 +1405,12 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
- variancePower=0.0, linkPower=None):
+ variancePower=0.0, linkPower=None, offsetCol=None):
"""
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
- variancePower=0.0, linkPower=None)
+ variancePower=0.0, linkPower=None, offsetCol=None)
Sets params for generalized linear regression.
"""
kwargs = self._input_kwargs
@@ -1481,6 +1489,20 @@ def getLinkPower(self):
"""
return self.getOrDefault(self.linkPower)
+ @since("2.3.0")
+ def setOffsetCol(self, value):
+ """
+ Sets the value of :py:attr:`offsetCol`.
+ """
+ return self._set(offsetCol=value)
+
+ @since("2.3.0")
+ def getOffsetCol(self):
+ """
+ Gets the value of offsetCol or its default value.
+ """
+ return self.getOrDefault(self.offsetCol)
+
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
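
A minimal sketch of the new offsetCol support (assuming a local Spark install with this patch applied), mirroring the data used in the test added below: per-row offsets, such as a log-exposure term in a Poisson model, are supplied through a column.

    from pyspark.sql import SparkSession
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.regression import GeneralizedLinearRegression

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    df = spark.createDataFrame(
        [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
         (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)),
         (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)),
         (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"])

    glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset")
    model = glr.fit(df)
    print(model.coefficients, model.intercept)

    spark.stop()
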
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 17a39472e1fe..787004765160 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -551,6 +551,27 @@ def test_rformula_string_indexer_order_type(self):
for i in range(0, len(expected)):
self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
+ def test_string_indexer_handle_invalid(self):
+ df = self.spark.createDataFrame([
+ (0, "a"),
+ (1, "d"),
+ (2, None)], ["id", "label"])
+
+ si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep",
+ stringOrderType="alphabetAsc")
+ model1 = si1.fit(df)
+ td1 = model1.transform(df)
+ actual1 = td1.select("id", "indexed").collect()
+ expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)]
+ self.assertEqual(actual1, expected1)
+
+ si2 = si1.setHandleInvalid("skip")
+ model2 = si2.fit(df)
+ td2 = model2.transform(df)
+ actual2 = td2.select("id", "indexed").collect()
+ expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
+ self.assertEqual(actual2, expected2)
+
class HasInducedError(Params):
@@ -1270,6 +1291,20 @@ def test_tweedie_distribution(self):
self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4))
self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))
+ def test_offset(self):
+
+ df = self.spark.createDataFrame(
+ [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
+ (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)),
+ (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)),
+ (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"])
+
+ glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset")
+ model = glr.fit(df)
+ self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581],
+ atol=1E-4))
+ self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4))
+
class FPGrowthTests(SparkSessionTestCase):
def setUp(self):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 873b00fbe534..334761009003 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -608,7 +608,7 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p
sort records by their keys.
>>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
- >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, 2)
+ >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True)
>>> rdd2.glom().collect()
[[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]]
"""
@@ -627,7 +627,6 @@ def sortPartition(iterator):
def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
"""
Sorts this RDD, which is assumed to consist of (key, value) pairs.
- # noqa
>>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
>>> sc.parallelize(tmp).sortByKey().first()
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index ea5e00e9eeef..d5c2a7518b18 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -182,6 +182,23 @@ def loads(self, obj):
raise NotImplementedError
+class ArrowSerializer(FramedSerializer):
+ """
+ Serializes an Arrow stream.
+ """
+
+ def dumps(self, obj):
+ raise NotImplementedError
+
+ def loads(self, obj):
+ import pyarrow as pa
+ reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
+ return reader.read_all()
+
+ def __repr__(self):
+ return "ArrowSerializer"
+
+
class BatchedSerializer(Serializer):
"""
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 426f07cd9410..c44ab247fd3d 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -232,6 +232,23 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):
jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
+ @ignore_unicode_prefix
+ @since(2.3)
+ def registerJavaUDAF(self, name, javaClassName):
+ """Register a java UDAF so it can be used in SQL statements.
+
+ :param name: name of the UDAF
+ :param javaClassName: fully qualified name of java class
+
+ >>> sqlContext.registerJavaUDAF("javaUDAF",
+ ... "test.org.apache.spark.sql.MyDoubleAvg")
+ >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
+ >>> df.registerTempTable("df")
+ >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
+ [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
+ """
+ self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
+
# TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
"""
@@ -551,6 +568,12 @@ def __init__(self, sqlContext):
def register(self, name, f, returnType=StringType()):
return self.sqlContext.registerFunction(name, f, returnType)
+ def registerJavaFunction(self, name, javaClassName, returnType=None):
+ self.sqlContext.registerJavaFunction(name, javaClassName, returnType)
+
+ def registerJavaUDAF(self, name, javaClassName):
+ self.sqlContext.registerJavaUDAF(name, javaClassName)
+
register.__doc__ = SQLContext.registerFunction.__doc__
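
A minimal sketch of the new registration path (assuming a local Spark install with this patch applied; MyDoubleAvg is the helper UDAF from Spark's own test jars, so substitute any UDAF class on your classpath). With the wrapper methods above, the same call is reachable through spark.udf as well as through the SQLContext doctest shown earlier.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
    df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"])
    df.createOrReplaceTempView("df")
    spark.sql("SELECT name, javaUDAF(id) AS avg FROM df GROUP BY name").show()

    spark.stop()
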
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 0649271ed224..944739bcd207 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -29,7 +29,8 @@
from pyspark import copy_func, since
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
+from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
+ UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import _parse_datatype_json_string
@@ -833,6 +834,8 @@ def join(self, other, on=None, how=None):
else:
if how is None:
how = "inner"
+ if on is None:
+ on = self._jseq([])
assert isinstance(how, basestring), "how should be basestring"
jdf = self._jdf.join(other._jdf, on, how)
return DataFrame(jdf, self.sql_ctx)
@@ -1708,7 +1711,8 @@ def toDF(self, *cols):
@since(1.3)
def toPandas(self):
- """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
+ """
+ Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
This is only available if Pandas is installed and available.
@@ -1721,18 +1725,42 @@ def toPandas(self):
1 5 Bob
"""
import pandas as pd
+ if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true":
+ try:
+ import pyarrow
+ tables = self._collectAsArrow()
+ if tables:
+ table = pyarrow.concat_tables(tables)
+ return table.to_pandas()
+ else:
+ return pd.DataFrame.from_records([], columns=self.columns)
+ except ImportError as e:
+ msg = "note: pyarrow must be installed and available on calling Python process " \
+ "if using spark.sql.execution.arrow.enable=true"
+ raise ImportError("%s\n%s" % (str(e), msg))
+ else:
+ dtype = {}
+ for field in self.schema:
+ pandas_type = _to_corrected_pandas_type(field.dataType)
+ if pandas_type is not None:
+ dtype[field.name] = pandas_type
- dtype = {}
- for field in self.schema:
- pandas_type = _to_corrected_pandas_type(field.dataType)
- if pandas_type is not None:
- dtype[field.name] = pandas_type
+ pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
- pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+ for f, t in dtype.items():
+ pdf[f] = pdf[f].astype(t, copy=False)
+ return pdf
- for f, t in dtype.items():
- pdf[f] = pdf[f].astype(t, copy=False)
- return pdf
+ def _collectAsArrow(self):
+ """
+ Returns all records as a list of deserialized ArrowPayloads; pyarrow must be
+ installed and available.
+
+ .. note:: Experimental.
+ """
+ with SCCallSiteSync(self._sc) as css:
+ port = self._jdf.collectAsArrowToPython()
+ return list(_load_from_socket(port, ArrowSerializer()))
##########################################################################################
# Pandas compatibility
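
A minimal sketch of the opt-in Arrow path added above (assuming a local Spark install with this patch applied and pyarrow importable on the driver): toPandas() only takes the Arrow branch when the experimental conf is enabled, otherwise it falls back to the original collect()-based path.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.range(5).toDF("id")

    spark.conf.set("spark.sql.execution.arrow.enable", "true")   # Arrow-backed collect
    pdf_arrow = df.toPandas()

    spark.conf.set("spark.sql.execution.arrow.enable", "false")  # original collect() path
    pdf_plain = df.toPandas()

    print(pdf_arrow.equals(pdf_plain))
    spark.stop()
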
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b14fc86fcb39..d2126b279f40 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -67,9 +67,14 @@ def _():
_.__doc__ = 'Window function: ' + doc
return _
+_lit_doc = """
+ Creates a :class:`Column` of literal value.
+ >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1)
+ [Row(height=5, spark_user=True)]
+ """
_functions = {
- 'lit': 'Creates a :class:`Column` of literal value.',
+ 'lit': _lit_doc,
'col': 'Returns a :class:`Column` based on the given column name.',
'column': 'Returns a :class:`Column` based on the given column name.',
'asc': 'Returns a sort expression based on the ascending order of the given column name.',
@@ -95,10 +100,13 @@ def _():
'0.0 through pi.',
'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' +
'-pi/2 through pi/2.',
- 'atan': 'Computes the tangent inverse of the given value.',
+ 'atan': 'Computes the tangent inverse of the given value; the returned angle is in the range ' +
+ '-pi/2 through pi/2.',
'cbrt': 'Computes the cube-root of the given value.',
'ceil': 'Computes the ceiling of the given value.',
- 'cos': 'Computes the cosine of the given value.',
+ 'cos': """Computes the cosine of the given value.
+
+ :param col: :class:`DoubleType` column, units in radians.""",
'cosh': 'Computes the hyperbolic cosine of the given value.',
'exp': 'Computes the exponential of the given value.',
'expm1': 'Computes the exponential of the given value minus one.',
@@ -109,15 +117,33 @@ def _():
'rint': 'Returns the double value that is closest in value to the argument and' +
' is equal to a mathematical integer.',
'signum': 'Computes the signum of the given value.',
- 'sin': 'Computes the sine of the given value.',
+ 'sin': """Computes the sine of the given value.
+
+ :param col: :class:`DoubleType` column, units in radians.""",
'sinh': 'Computes the hyperbolic sine of the given value.',
- 'tan': 'Computes the tangent of the given value.',
+ 'tan': """Computes the tangent of the given value.
+
+ :param col: :class:`DoubleType` column, units in radians.""",
'tanh': 'Computes the hyperbolic tangent of the given value.',
- 'toDegrees': '.. note:: Deprecated in 2.1, use degrees instead.',
- 'toRadians': '.. note:: Deprecated in 2.1, use radians instead.',
+ 'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.',
+ 'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.',
'bitwiseNOT': 'Computes bitwise not.',
}
+_collect_list_doc = """
+ Aggregate function: returns a list of objects with duplicates.
+
+ >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
+ >>> df2.agg(collect_list('age')).collect()
+ [Row(collect_list(age)=[2, 5, 5])]
+ """
+_collect_set_doc = """
+ Aggregate function: returns a set of objects with duplicate elements eliminated.
+
+ >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
+ >>> df2.agg(collect_set('age')).collect()
+ [Row(collect_set(age)=[5, 2])]
+ """
_functions_1_6 = {
# unary math functions
'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' +
@@ -131,9 +157,8 @@ def _():
'var_pop': 'Aggregate function: returns the population variance of the values in a group.',
'skewness': 'Aggregate function: returns the skewness of the values in a group.',
'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.',
- 'collect_list': 'Aggregate function: returns a list of objects with duplicates.',
- 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' +
- ' eliminated.',
+ 'collect_list': _collect_list_doc,
+ 'collect_set': _collect_set_doc
}
_functions_2_1 = {
@@ -147,7 +172,7 @@ def _():
# math functions that take two arguments as input
_binary_mathfunctions = {
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
- 'polar coordinates (r, theta).',
+ 'polar coordinates (r, theta). Units in radians.',
'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.',
'pow': 'Returns the value of the first argument raised to the power of the second argument.',
}
@@ -200,17 +225,20 @@ def _():
@since(1.3)
def approxCountDistinct(col, rsd=None):
"""
- .. note:: Deprecated in 2.1, use approx_count_distinct instead.
+ .. note:: Deprecated in 2.1, use :func:`approx_count_distinct` instead.
"""
return approx_count_distinct(col, rsd)
@since(2.1)
def approx_count_distinct(col, rsd=None):
- """Returns a new :class:`Column` for approximate distinct count of ``col``.
+ """Aggregate function: returns a new :class:`Column` for approximate distinct count of column `col`.
- >>> df.agg(approx_count_distinct(df.age).alias('c')).collect()
- [Row(c=2)]
+ :param rsd: maximum estimation error allowed (default = 0.05). For rsd < 0.01, it is more
+ efficient to use :func:`countDistinct`
+
+ >>> df.agg(approx_count_distinct(df.age).alias('distinct_ages')).collect()
+ [Row(distinct_ages=2)]
"""
sc = SparkContext._active_spark_context
if rsd is None:
@@ -267,8 +295,7 @@ def coalesce(*cols):
@since(1.6)
def corr(col1, col2):
- """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1``
- and ``col2``.
+ """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` and ``col2``.
>>> a = range(20)
>>> b = [2 * x for x in range(20)]
@@ -282,8 +309,7 @@ def corr(col1, col2):
@since(2.0)
def covar_pop(col1, col2):
- """Returns a new :class:`Column` for the population covariance of ``col1``
- and ``col2``.
+ """Returns a new :class:`Column` for the population covariance of ``col1`` and ``col2``.
>>> a = [1] * 10
>>> b = [1] * 10
@@ -297,8 +323,7 @@ def covar_pop(col1, col2):
@since(2.0)
def covar_samp(col1, col2):
- """Returns a new :class:`Column` for the sample covariance of ``col1``
- and ``col2``.
+ """Returns a new :class:`Column` for the sample covariance of ``col1`` and ``col2``.
>>> a = [1] * 10
>>> b = [1] * 10
@@ -450,7 +475,7 @@ def monotonically_increasing_id():
def nanvl(col1, col2):
"""Returns col1 if it is not NaN, or col2 if col1 is NaN.
- Both inputs should be floating point columns (DoubleType or FloatType).
+ Both inputs should be floating point columns (:class:`DoubleType` or :class:`FloatType`).
>>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b"))
>>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect()
@@ -460,10 +485,15 @@ def nanvl(col1, col2):
return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2)))
+@ignore_unicode_prefix
@since(1.4)
def rand(seed=None):
"""Generates a random column with independent and identically distributed (i.i.d.) samples
from U[0.0, 1.0].
+
+ >>> df.withColumn('rand', rand(seed=42) * 3).collect()
+ [Row(age=2, name=u'Alice', rand=1.1568609015300986),
+ Row(age=5, name=u'Bob', rand=1.403379671529166)]
"""
sc = SparkContext._active_spark_context
if seed is not None:
@@ -473,10 +503,15 @@ def rand(seed=None):
return Column(jc)
+@ignore_unicode_prefix
@since(1.4)
def randn(seed=None):
"""Generates a column with independent and identically distributed (i.i.d.) samples from
the standard normal distribution.
+
+ >>> df.withColumn('randn', randn(seed=42)).collect()
+ [Row(age=2, name=u'Alice', randn=-0.7556247885860078),
+ Row(age=5, name=u'Bob', randn=-0.0861619008451133)]
"""
sc = SparkContext._active_spark_context
if seed is not None:
@@ -760,7 +795,7 @@ def ntile(n):
@since(1.5)
def current_date():
"""
- Returns the current date as a date column.
+ Returns the current date as a :class:`DateType` column.
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.current_date())
@@ -768,7 +803,7 @@ def current_date():
def current_timestamp():
"""
- Returns the current timestamp as a timestamp column.
+ Returns the current timestamp as a :class:`TimestampType` column.
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.current_timestamp())
@@ -787,8 +822,8 @@ def date_format(date, format):
.. note:: Use when ever possible specialized functions like `year`. These benefit from a
specialized implementation.
- >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
- >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect()
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect()
[Row(date=u'04/08/2015')]
"""
sc = SparkContext._active_spark_context
@@ -800,8 +835,8 @@ def year(col):
"""
Extract the year of a given date as integer.
- >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
- >>> df.select(year('a').alias('year')).collect()
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(year('dt').alias('year')).collect()
[Row(year=2015)]
"""
sc = SparkContext._active_spark_context
@@ -813,8 +848,8 @@ def quarter(col):
"""
Extract the quarter of a given date as integer.
- >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
- >>> df.select(quarter('a').alias('quarter')).collect()
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(quarter('dt').alias('quarter')).collect()
[Row(quarter=2)]
"""
sc = SparkContext._active_spark_context
@@ -826,8 +861,8 @@ def month(col):
"""
Extract the month of a given date as integer.
- >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
- >>> df.select(month('a').alias('month')).collect()
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(month('dt').alias('month')).collect()
[Row(month=4)]
"""
sc = SparkContext._active_spark_context
@@ -839,8 +874,8 @@ def dayofmonth(col):
"""
Extract the day of the month of a given date as integer.
- >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
- >>> df.select(dayofmonth('a').alias('day')).collect()
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(dayofmonth('dt').alias('day')).collect()
[Row(day=8)]
"""
sc = SparkContext._active_spark_context
@@ -852,8 +887,8 @@ def dayofyear(col):
"""
Extract the day of the year of a given date as integer.
- >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
- >>> df.select(dayofyear('a').alias('day')).collect()
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(dayofyear('dt').alias('day')).collect()
[Row(day=98)]
"""
sc = SparkContext._active_spark_context
@@ -865,8 +900,8 @@ def hour(col):
"""
Extract the hours of a given date as integer.
- >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
- >>> df.select(hour('a').alias('hour')).collect()
+ >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])
+ >>> df.select(hour('ts').alias('hour')).collect()
[Row(hour=13)]
"""
sc = SparkContext._active_spark_context
@@ -878,8 +913,8 @@ def minute(col):
"""
Extract the minutes of a given date as integer.
- >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
- >>> df.select(minute('a').alias('minute')).collect()
+ >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])
+ >>> df.select(minute('ts').alias('minute')).collect()
[Row(minute=8)]
"""
sc = SparkContext._active_spark_context
@@ -891,8 +926,8 @@ def second(col):
"""
Extract the seconds of a given date as integer.
- >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a'])
- >>> df.select(second('a').alias('second')).collect()
+ >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts'])
+ >>> df.select(second('ts').alias('second')).collect()
[Row(second=15)]
"""
sc = SparkContext._active_spark_context
@@ -904,8 +939,8 @@ def weekofyear(col):
"""
Extract the week number of a given date as integer.
- >>> df = spark.createDataFrame([('2015-04-08',)], ['a'])
- >>> df.select(weekofyear(df.a).alias('week')).collect()
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(weekofyear(df.dt).alias('week')).collect()
[Row(week=15)]
"""
sc = SparkContext._active_spark_context
@@ -917,9 +952,9 @@ def date_add(start, days):
"""
Returns the date that is `days` days after `start`
- >>> df = spark.createDataFrame([('2015-04-08',)], ['d'])
- >>> df.select(date_add(df.d, 1).alias('d')).collect()
- [Row(d=datetime.date(2015, 4, 9))]
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(date_add(df.dt, 1).alias('next_date')).collect()
+ [Row(next_date=datetime.date(2015, 4, 9))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.date_add(_to_java_column(start), days))
@@ -930,9 +965,9 @@ def date_sub(start, days):
"""
Returns the date that is `days` days before `start`
- >>> df = spark.createDataFrame([('2015-04-08',)], ['d'])
- >>> df.select(date_sub(df.d, 1).alias('d')).collect()
- [Row(d=datetime.date(2015, 4, 7))]
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(date_sub(df.dt, 1).alias('prev_date')).collect()
+ [Row(prev_date=datetime.date(2015, 4, 7))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))
@@ -956,9 +991,9 @@ def add_months(start, months):
"""
Returns the date that is `months` months after `start`
- >>> df = spark.createDataFrame([('2015-04-08',)], ['d'])
- >>> df.select(add_months(df.d, 1).alias('d')).collect()
- [Row(d=datetime.date(2015, 5, 8))]
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(add_months(df.dt, 1).alias('next_month')).collect()
+ [Row(next_month=datetime.date(2015, 5, 8))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.add_months(_to_java_column(start), months))
@@ -969,8 +1004,8 @@ def months_between(date1, date2):
"""
Returns the number of months between date1 and date2.
- >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
- >>> df.select(months_between(df.t, df.d).alias('months')).collect()
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2'])
+ >>> df.select(months_between(df.date1, df.date2).alias('months')).collect()
[Row(months=3.9495967...)]
"""
sc = SparkContext._active_spark_context
@@ -1073,12 +1108,19 @@ def last_day(date):
return Column(sc._jvm.functions.last_day(_to_java_column(date)))
+@ignore_unicode_prefix
@since(1.5)
def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"):
"""
Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
representing the timestamp of that moment in the current system time zone in the given
format.
+
+ >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
+ >>> time_df = spark.createDataFrame([(1428476400,)], ['unix_time'])
+ >>> time_df.select(from_unixtime('unix_time').alias('ts')).collect()
+ [Row(ts=u'2015-04-08 00:00:00')]
+ >>> spark.conf.unset("spark.sql.session.timeZone")
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format))
@@ -1092,6 +1134,12 @@ def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'):
locale, return null if fail.
if `timestamp` is None, then it returns current timestamp.
+
+ >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
+ >>> time_df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> time_df.select(unix_timestamp('dt', 'yyyy-MM-dd').alias('unix_time')).collect()
+ [Row(unix_time=1428476400)]
+ >>> spark.conf.unset("spark.sql.session.timeZone")
"""
sc = SparkContext._active_spark_context
if timestamp is None:
@@ -1106,8 +1154,8 @@ def from_utc_timestamp(timestamp, tz):
that corresponds to the same time of day in the given timezone.
>>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
- >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect()
- [Row(t=datetime.datetime(1997, 2, 28, 2, 30))]
+ >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect()
+ [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz))
@@ -1119,9 +1167,9 @@ def to_utc_timestamp(timestamp, tz):
Given a timestamp, which corresponds to a certain time of day in the given timezone, returns
another timestamp that corresponds to the same time of day in UTC.
- >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
- >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect()
- [Row(t=datetime.datetime(1997, 2, 28, 18, 30))]
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts'])
+ >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect()
+ [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz))
@@ -1839,15 +1887,20 @@ def from_json(col, schema, options={}):
string.
:param col: string column in json format
- :param schema: a StructType or ArrayType of StructType to use when parsing the json column
+ :param schema: a StructType or ArrayType of StructType to use when parsing the json column.
:param options: options to control parsing. accepts the same options as the json datasource
+ .. note:: Since Spark 2.3, a DDL-formatted string or a JSON format string is also
+ supported for ``schema``.
+
>>> from pyspark.sql.types import *
>>> data = [(1, '''{"a": 1}''')]
>>> schema = StructType([StructField("a", IntegerType())])
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=Row(a=1))]
+ >>> df.select(from_json(df.value, "a INT").alias("json")).collect()
+ [Row(json=Row(a=1))]
>>> data = [(1, '''[{"a": 1}]''')]
>>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
>>> df = spark.createDataFrame(data, ("key", "value"))
@@ -1856,7 +1909,9 @@ def from_json(col, schema, options={}):
"""
sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.from_json(_to_java_column(col), schema.json(), options)
+ if isinstance(schema, DataType):
+ schema = schema.json()
+ jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options)
return Column(jc)
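
A minimal sketch of the DDL-schema variant shown in the doctest above (assuming a local Spark install with this patch applied): the schema string is handed to the JVM as-is and parsed there, so a multi-field DDL definition works the same way as a StructType.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import from_json

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    df = spark.createDataFrame([(1, '{"a": 1, "b": "x"}')], ("key", "value"))
    # The schema can now be a DDL-formatted string instead of a StructType.
    df.select(from_json(df.value, "a INT, b STRING").alias("json")).show()

    spark.stop()
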
@@ -1983,15 +2038,25 @@ def __init__(self, func, returnType, name=None):
"{0}".format(type(func)))
self.func = func
- self.returnType = (
- returnType if isinstance(returnType, DataType)
- else _parse_datatype_string(returnType))
+ self._returnType = returnType
# Stores UserDefinedPythonFunctions jobj, once initialized
+ self._returnType_placeholder = None
self._judf_placeholder = None
self._name = name or (
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
+ @property
+ def returnType(self):
+ # This makes sure this is called after SparkContext is initialized.
+ # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
+ if self._returnType_placeholder is None:
+ if isinstance(self._returnType, DataType):
+ self._returnType_placeholder = self._returnType
+ else:
+ self._returnType_placeholder = _parse_datatype_string(self._returnType)
+ return self._returnType_placeholder
+
@property
def _judf(self):
# It is possible that concurrent access, to newly created UDF,
@@ -2096,7 +2161,7 @@ def _test():
sc = spark.sparkContext
globs['sc'] = sc
globs['spark'] = spark
- globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
+ globs['df'] = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)])
(failure_count, test_count) = doctest.testmod(
pyspark.sql.functions, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
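
A minimal sketch of the lazily parsed returnType (assuming a local Spark install with this patch applied): because parsing is deferred to the new returnType property, a UDF can be declared with a DDL-formatted string before any SparkContext exists.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf

    # Declaring with a string return type no longer needs a running SparkContext.
    plus_one = udf(lambda x: x + 1, "int")

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    spark.range(3).select(plus_one("id").alias("id_plus_one")).show()
    spark.stop()
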
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index e3bf0f35ea15..2cc0e2d1d7b8 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -33,7 +33,7 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
+from pyspark.sql.types import Row, DataType, StringType, StructType, _make_type_verifier, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
from pyspark.sql.utils import install_exception_handler
@@ -514,17 +514,21 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
schema = [str(x) for x in data.columns]
data = [r.tolist() for r in data.to_records(index=False)]
- verify_func = _verify_type if verifySchema else lambda _, t: True
if isinstance(schema, StructType):
+ verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
+
def prepare(obj):
- verify_func(obj, schema)
+ verify_func(obj)
return obj
elif isinstance(schema, DataType):
dataType = schema
schema = StructType().add("value", schema)
+ verify_func = _make_type_verifier(
+ dataType, name="field value") if verifySchema else lambda _: True
+
def prepare(obj):
- verify_func(obj, dataType)
+ verify_func(obj)
return obj,
else:
if isinstance(schema, list):
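
A minimal sketch of the verifier factory that createDataFrame now uses (assuming this patch is applied; _make_type_verifier is an internal helper, shown here only to illustrate the refactor): the verifier is built once per schema and then applied to each row, instead of re-dispatching on the data type for every row.

    from pyspark.sql.types import (LongType, StringType, StructField, StructType,
                                   _make_type_verifier)

    schema = StructType([StructField("id", LongType(), False),
                         StructField("name", StringType(), True)])

    verify = _make_type_verifier(schema)   # built once, like verify_func above
    verify((1, "Alice"))                   # passes silently
    verify((2, None))                      # ok: name is nullable
    # verify((None, "Bob")) would raise ValueError because field id is not nullable
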
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0a1cd6856b8e..29e48a6ccf76 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -57,13 +57,22 @@
from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
-from pyspark.sql.types import UserDefinedType, _infer_type
-from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
+from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
+from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
+_have_arrow = False
+try:
+ import pyarrow
+ _have_arrow = True
+except:
+ # No Arrow, but that's okay, we'll skip those tests
+ pass
+
+
class UTCOffsetTimezone(datetime.tzinfo):
"""
Specifies timezone in UTC offset
@@ -481,6 +490,16 @@ def test_udf_registration_returns_udf(self):
df.select(add_three("id").alias("plus_three")).collect()
)
+ def test_non_existed_udf(self):
+ spark = self.spark
+ self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
+ lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
+
+ def test_non_existed_udaf(self):
+ spark = self.spark
+ self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
+ lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
+
def test_multiLine_json(self):
people1 = self.spark.read.json("python/test_support/sql/people.json")
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
@@ -852,7 +871,7 @@ def test_convert_row_to_dict(self):
self.assertEqual(1.0, row.asDict()['d']['key'].c)
def test_udt(self):
- from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
+ from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
def check_datatype(datatype):
@@ -868,8 +887,8 @@ def check_datatype(datatype):
check_datatype(structtype_with_udt)
p = ExamplePoint(1.0, 2.0)
self.assertEqual(_infer_type(p), ExamplePointUDT())
- _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
- self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
+ _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
+ self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))
check_datatype(PythonOnlyUDT())
structtype_with_udt = StructType([StructField("label", DoubleType(), False),
@@ -877,8 +896,10 @@ def check_datatype(datatype):
check_datatype(structtype_with_udt)
p = PythonOnlyPoint(1.0, 2.0)
self.assertEqual(_infer_type(p), PythonOnlyUDT())
- _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
- self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
+ _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
+ self.assertRaises(
+ ValueError,
+ lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0]))
def test_simple_udt_in_df(self):
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
@@ -1234,6 +1255,31 @@ def test_struct_type(self):
with self.assertRaises(TypeError):
not_a_field = struct1[9.9]
+ def test_parse_datatype_string(self):
+ from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
+ for k, t in _all_atomic_types.items():
+ if t != NullType:
+ self.assertEqual(t(), _parse_datatype_string(k))
+ self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+ self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
+ self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
+ self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
+ self.assertEqual(
+ ArrayType(IntegerType()),
+ _parse_datatype_string("array"))
+ self.assertEqual(
+ MapType(IntegerType(), DoubleType()),
+ _parse_datatype_string("map< int, double >"))
+ self.assertEqual(
+ StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+ _parse_datatype_string("struct"))
+ self.assertEqual(
+ StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+ _parse_datatype_string("a:int, c:double"))
+ self.assertEqual(
+ StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+ _parse_datatype_string("a INT, c DOUBLE"))
+
def test_metadata_null(self):
from pyspark.sql.types import StructType, StringType, StructField
schema = StructType([StructField("f1", StringType(), True, None),
@@ -2021,6 +2067,22 @@ def test_toDF_with_schema_string(self):
self.assertEqual(df.schema.simpleString(), "struct<value:int>")
self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
+ def test_join_without_on(self):
+ df1 = self.spark.range(1).toDF("a")
+ df2 = self.spark.range(1).toDF("b")
+
+ try:
+ self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
+ self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
+
+ self.spark.conf.set("spark.sql.crossJoin.enabled", "true")
+ actual = df1.join(df2, how="inner").collect()
+ expected = [Row(a=0, b=0)]
+ self.assertEqual(actual, expected)
+ finally:
+ # We should unset this. Otherwise, other tests are affected.
+ self.spark.conf.unset("spark.sql.crossJoin.enabled")
+
# Regression test for invalid join methods when on is None, Spark-14761
def test_invalid_join_method(self):
df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
@@ -2314,6 +2376,12 @@ def test_to_pandas(self):
self.assertEquals(types[2], np.bool)
self.assertEquals(types[3], np.float32)
+ def test_create_dataframe_from_array_of_long(self):
+ import array
+ data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))]
+ df = self.spark.createDataFrame(data)
+ self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))
+
class HiveSparkSubmitTests(SparkSubmitTests):
@@ -2620,6 +2688,262 @@ def range_frame_match():
importlib.reload(window)
+
+class DataTypeVerificationTests(unittest.TestCase):
+
+ def test_verify_type_exception_msg(self):
+ self.assertRaisesRegexp(
+ ValueError,
+ "test_name",
+ lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))
+
+ schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
+ self.assertRaisesRegexp(
+ TypeError,
+ "field b in field a",
+ lambda: _make_type_verifier(schema)([["data"]]))
+
+ def test_verify_type_ok_nullable(self):
+ obj = None
+ types = [IntegerType(), FloatType(), StringType(), StructType([])]
+ for data_type in types:
+ try:
+ _make_type_verifier(data_type, nullable=True)(obj)
+ except Exception:
+ self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type))
+
+ def test_verify_type_not_nullable(self):
+ import array
+ import datetime
+ import decimal
+
+ schema = StructType([
+ StructField('s', StringType(), nullable=False),
+ StructField('i', IntegerType(), nullable=True)])
+
+ class MyObj:
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ # obj, data_type
+ success_spec = [
+ # String
+ ("", StringType()),
+ (u"", StringType()),
+ (1, StringType()),
+ (1.0, StringType()),
+ ([], StringType()),
+ ({}, StringType()),
+
+ # UDT
+ (ExamplePoint(1.0, 2.0), ExamplePointUDT()),
+
+ # Boolean
+ (True, BooleanType()),
+
+ # Byte
+ (-(2**7), ByteType()),
+ (2**7 - 1, ByteType()),
+
+ # Short
+ (-(2**15), ShortType()),
+ (2**15 - 1, ShortType()),
+
+ # Integer
+ (-(2**31), IntegerType()),
+ (2**31 - 1, IntegerType()),
+
+ # Long
+ (2**64, LongType()),
+
+ # Float & Double
+ (1.0, FloatType()),
+ (1.0, DoubleType()),
+
+ # Decimal
+ (decimal.Decimal("1.0"), DecimalType()),
+
+ # Binary
+ (bytearray([1, 2]), BinaryType()),
+
+ # Date/Timestamp
+ (datetime.date(2000, 1, 2), DateType()),
+ (datetime.datetime(2000, 1, 2, 3, 4), DateType()),
+ (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()),
+
+ # Array
+ ([], ArrayType(IntegerType())),
+ (["1", None], ArrayType(StringType(), containsNull=True)),
+ ([1, 2], ArrayType(IntegerType())),
+ ((1, 2), ArrayType(IntegerType())),
+ (array.array('h', [1, 2]), ArrayType(IntegerType())),
+
+ # Map
+ ({}, MapType(StringType(), IntegerType())),
+ ({"a": 1}, MapType(StringType(), IntegerType())),
+ ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)),
+
+ # Struct
+ ({"s": "a", "i": 1}, schema),
+ ({"s": "a", "i": None}, schema),
+ ({"s": "a"}, schema),
+ ({"s": "a", "f": 1.0}, schema),
+ (Row(s="a", i=1), schema),
+ (Row(s="a", i=None), schema),
+ (Row(s="a", i=1, f=1.0), schema),
+ (["a", 1], schema),
+ (["a", None], schema),
+ (("a", 1), schema),
+ (MyObj(s="a", i=1), schema),
+ (MyObj(s="a", i=None), schema),
+ (MyObj(s="a"), schema),
+ ]
+
+ # obj, data_type, exception class
+ failure_spec = [
+ # String (match anything but None)
+ (None, StringType(), ValueError),
+
+ # UDT
+ (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
+
+ # Boolean
+ (1, BooleanType(), TypeError),
+ ("True", BooleanType(), TypeError),
+ ([1], BooleanType(), TypeError),
+
+ # Byte
+ (-(2**7) - 1, ByteType(), ValueError),
+ (2**7, ByteType(), ValueError),
+ ("1", ByteType(), TypeError),
+ (1.0, ByteType(), TypeError),
+
+ # Short
+ (-(2**15) - 1, ShortType(), ValueError),
+ (2**15, ShortType(), ValueError),
+
+ # Integer
+ (-(2**31) - 1, IntegerType(), ValueError),
+ (2**31, IntegerType(), ValueError),
+
+ # Float & Double
+ (1, FloatType(), TypeError),
+ (1, DoubleType(), TypeError),
+
+ # Decimal
+ (1.0, DecimalType(), TypeError),
+ (1, DecimalType(), TypeError),
+ ("1.0", DecimalType(), TypeError),
+
+ # Binary
+ (1, BinaryType(), TypeError),
+
+ # Date/Timestamp
+ ("2000-01-02", DateType(), TypeError),
+ (946811040, TimestampType(), TypeError),
+
+ # Array
+ (["1", None], ArrayType(StringType(), containsNull=False), ValueError),
+ ([1, "2"], ArrayType(IntegerType()), TypeError),
+
+ # Map
+ ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError),
+ ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError),
+ ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False),
+ ValueError),
+
+ # Struct
+ ({"s": "a", "i": "1"}, schema, TypeError),
+ (Row(s="a"), schema, ValueError), # Row can't have missing field
+ (Row(s="a", i="1"), schema, TypeError),
+ (["a"], schema, ValueError),
+ (["a", "1"], schema, TypeError),
+ (MyObj(s="a", i="1"), schema, TypeError),
+ (MyObj(s=None, i="1"), schema, ValueError),
+ ]
+
+ # Check success cases
+ for obj, data_type in success_spec:
+ try:
+ _make_type_verifier(data_type, nullable=False)(obj)
+ except Exception:
+ self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type))
+
+ # Check failure cases
+ for obj, data_type, exp in failure_spec:
+ msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
+ with self.assertRaises(exp, msg=msg):
+ _make_type_verifier(data_type, nullable=False)(obj)
+
+
+@unittest.skipIf(not _have_arrow, "Arrow not installed")
+class ArrowTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.spark = SparkSession(cls.sc)
+ cls.spark.conf.set("spark.sql.execution.arrow.enable", "true")
+ cls.schema = StructType([
+ StructField("1_str_t", StringType(), True),
+ StructField("2_int_t", IntegerType(), True),
+ StructField("3_long_t", LongType(), True),
+ StructField("4_float_t", FloatType(), True),
+ StructField("5_double_t", DoubleType(), True)])
+ cls.data = [("a", 1, 10, 0.2, 2.0),
+ ("b", 2, 20, 0.4, 4.0),
+ ("c", 3, 30, 0.8, 6.0)]
+
+ def assertFramesEqual(self, df_with_arrow, df_without):
+ msg = ("DataFrame from Arrow is not equal" +
+ ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) +
+ ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
+ self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
+
+ def test_unsupported_datatype(self):
+ schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)])
+ df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema)
+ with QuietTest(self.sc):
+ self.assertRaises(Exception, lambda: df.toPandas())
+
+ def test_null_conversion(self):
+ df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
+ self.data)
+ pdf = df_null.toPandas()
+ null_counts = pdf.isnull().sum().tolist()
+ self.assertTrue(all([c == 1 for c in null_counts]))
+
+ def test_toPandas_arrow_toggle(self):
+ df = self.spark.createDataFrame(self.data, schema=self.schema)
+ self.spark.conf.set("spark.sql.execution.arrow.enable", "false")
+ pdf = df.toPandas()
+ self.spark.conf.set("spark.sql.execution.arrow.enable", "true")
+ pdf_arrow = df.toPandas()
+ self.assertFramesEqual(pdf_arrow, pdf)
+
+ def test_pandas_round_trip(self):
+ import pandas as pd
+ import numpy as np
+ data_dict = {}
+ for j, name in enumerate(self.schema.names):
+ data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
+ # need to convert these to numpy types first
+ data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
+ data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
+ pdf = pd.DataFrame(data=data_dict)
+ df = self.spark.createDataFrame(self.data, schema=self.schema)
+ pdf_arrow = df.toPandas()
+ self.assertFramesEqual(pdf_arrow, pdf)
+
+ def test_filtered_frame(self):
+ df = self.spark.range(3).toDF("i")
+ pdf = df.filter("i < 0").toPandas()
+ self.assertEqual(len(pdf.columns), 1)
+ self.assertEqual(pdf.columns[0], "i")
+ self.assertTrue(pdf.empty)
+
+
if __name__ == "__main__":
from pyspark.sql.tests import *
if xmlrunner:
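
The ArrowTests added above flip `spark.sql.execution.arrow.enable` around `toPandas()` so the Arrow-backed and row-based conversion paths can be compared on the same DataFrame. A minimal standalone sketch of that pattern, assuming a local SparkSession with pandas and pyarrow installed; the app name and column names are illustrative only:

from pyspark.sql import SparkSession

# Sketch only: compare the Arrow-backed and row-based toPandas() results.
spark = SparkSession.builder.master("local[2]").appName("arrow-toggle-demo").getOrCreate()
df = spark.createDataFrame([("a", 1, 2.0), ("b", 2, 4.0)], ["s", "i", "d"])

spark.conf.set("spark.sql.execution.arrow.enable", "false")
pdf_plain = df.toPandas()   # row-by-row conversion

spark.conf.set("spark.sql.execution.arrow.enable", "true")
pdf_arrow = df.toPandas()   # Arrow-based conversion

assert pdf_plain.equals(pdf_arrow), "both paths should produce the same pandas DataFrame"
spark.stop()
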
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 26b54a7fb370..22fa273fc1aa 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -32,6 +32,7 @@
from py4j.protocol import register_input_converter
from py4j.java_gateway import JavaClass
+from pyspark import SparkContext
from pyspark.serializers import CloudPickleSerializer
__all__ = [
@@ -727,18 +728,6 @@ def __eq__(self, other):
_BRACKETS = {'(': ')', '[': ']', '{': '}'}
-def _parse_basic_datatype_string(s):
- if s in _all_atomic_types.keys():
- return _all_atomic_types[s]()
- elif s == "int":
- return IntegerType()
- elif _FIXED_DECIMAL.match(s):
- m = _FIXED_DECIMAL.match(s)
- return DecimalType(int(m.group(1)), int(m.group(2)))
- else:
- raise ValueError("Could not parse datatype: %s" % s)
-
-
def _ignore_brackets_split(s, separator):
"""
Splits the given string by given separator, but ignore separators inside brackets pairs, e.g.
@@ -771,32 +760,23 @@ def _ignore_brackets_split(s, separator):
return parts
-def _parse_struct_fields_string(s):
- parts = _ignore_brackets_split(s, ",")
- fields = []
- for part in parts:
- name_and_type = _ignore_brackets_split(part, ":")
- if len(name_and_type) != 2:
- raise ValueError("The strcut field string format is: 'field_name:field_type', " +
- "but got: %s" % part)
- field_name = name_and_type[0].strip()
- field_type = _parse_datatype_string(name_and_type[1])
- fields.append(StructField(field_name, field_type))
- return StructType(fields)
-
-
def _parse_datatype_string(s):
"""
Parses the given data type string to a :class:`DataType`. The data type string format equals
to :class:`DataType.simpleString`, except that top level struct type can omit
the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
- for :class:`IntegerType`.
+ for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
+ string and case-insensitive strings.
>>> _parse_datatype_string("int ")
IntegerType
+ >>> _parse_datatype_string("INT ")
+ IntegerType
>>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
+ >>> _parse_datatype_string("a DOUBLE, b STRING")
+ StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
>>> _parse_datatype_string("a: array< short>")
StructType(List(StructField(a,ArrayType(ShortType,true),true)))
>>> _parse_datatype_string(" map ")
@@ -806,43 +786,43 @@ def _parse_datatype_string(s):
>>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
- ValueError:...
+ ParseException:...
>>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
- ValueError:...
+ ParseException:...
>>> _parse_datatype_string("array>> _parse_datatype_string("map>") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
- ValueError:...
+ ParseException:...
"""
- s = s.strip()
- if s.startswith("array<"):
- if s[-1] != ">":
- raise ValueError("'>' should be the last char, but got: %s" % s)
- return ArrayType(_parse_datatype_string(s[6:-1]))
- elif s.startswith("map<"):
- if s[-1] != ">":
- raise ValueError("'>' should be the last char, but got: %s" % s)
- parts = _ignore_brackets_split(s[4:-1], ",")
- if len(parts) != 2:
- raise ValueError("The map type string format is: 'map', " +
- "but got: %s" % s)
- kt = _parse_datatype_string(parts[0])
- vt = _parse_datatype_string(parts[1])
- return MapType(kt, vt)
- elif s.startswith("struct<"):
- if s[-1] != ">":
- raise ValueError("'>' should be the last char, but got: %s" % s)
- return _parse_struct_fields_string(s[7:-1])
- elif ":" in s:
- return _parse_struct_fields_string(s)
- else:
- return _parse_basic_datatype_string(s)
+ sc = SparkContext._active_spark_context
+
+ def from_ddl_schema(type_str):
+ return _parse_datatype_json_string(
+ sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())
+
+ def from_ddl_datatype(type_str):
+ return _parse_datatype_json_string(
+ sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())
+
+ try:
+ # DDL format, "fieldname datatype, fieldname datatype".
+ return from_ddl_schema(s)
+ except Exception as e:
+ try:
+ # For backwards compatibility, "integer", "struct" and etc.
+ return from_ddl_datatype(s)
+ except:
+ try:
+ # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
+ return from_ddl_datatype("struct<%s>" % s.strip())
+ except:
+ raise e
def _parse_datatype_json_string(json_string):
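
With the rewrite above, `_parse_datatype_string` hands the string to the JVM-side DDL parser, so the legacy `name: type` spelling and the DDL `name TYPE` spelling resolve to the same schema, and atomic type names are matched case-insensitively. A minimal sketch, assuming an active SparkSession (the parse goes through `SparkContext._active_spark_context`):

from pyspark.sql import SparkSession
from pyspark.sql.types import _parse_datatype_string

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Both spellings should yield an equivalent StructType.
legacy = _parse_datatype_string("a: int, b: string")
ddl = _parse_datatype_string("a INT, b STRING")
assert legacy == ddl

# Type names parse case-insensitively.
assert _parse_datatype_string("int") == _parse_datatype_string("INT")
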
@@ -1249,121 +1229,196 @@ def _infer_schema_type(obj, dataType):
}
-def _verify_type(obj, dataType, nullable=True):
+def _make_type_verifier(dataType, nullable=True, name=None):
"""
- Verify the type of obj against dataType, raise a TypeError if they do not match.
-
- Also verify the value of obj against datatype, raise a ValueError if it's not within the allowed
- range, e.g. using 128 as ByteType will overflow. Note that, Python float is not checked, so it
- will become infinity when cast to Java float if it overflows.
-
- >>> _verify_type(None, StructType([]))
- >>> _verify_type("", StringType())
- >>> _verify_type(0, LongType())
- >>> _verify_type(list(range(3)), ArrayType(ShortType()))
- >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Make a verifier that checks the type of obj against dataType and raises a TypeError if they do
+ not match.
+
+ This verifier also checks the value of obj against datatype and raises a ValueError if it's not
+ within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is
+ not checked, so it will become infinity when cast to Java float if it overflows.
+
+ >>> _make_type_verifier(StructType([]))(None)
+ >>> _make_type_verifier(StringType())("")
+ >>> _make_type_verifier(LongType())(0)
+ >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3)))
+ >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TypeError:...
- >>> _verify_type({}, MapType(StringType(), IntegerType()))
- >>> _verify_type((), StructType([]))
- >>> _verify_type([], StructType([]))
- >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL
+ >>> _make_type_verifier(MapType(StringType(), IntegerType()))({})
+ >>> _make_type_verifier(StructType([]))(())
+ >>> _make_type_verifier(StructType([]))([])
+ >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> # Check if numeric values are within the allowed range.
- >>> _verify_type(12, ByteType())
- >>> _verify_type(1234, ByteType()) # doctest: +IGNORE_EXCEPTION_DETAIL
+ >>> _make_type_verifier(ByteType())(12)
+ >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
- >>> _verify_type(None, ByteType(), False) # doctest: +IGNORE_EXCEPTION_DETAIL
+ >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
- >>> _verify_type([1, None], ArrayType(ShortType(), False)) # doctest: +IGNORE_EXCEPTION_DETAIL
+ >>> _make_type_verifier(
+ ... ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
- >>> _verify_type({None: 1}, MapType(StringType(), IntegerType()))
+ >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1})
Traceback (most recent call last):
...
ValueError:...
>>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False)
- >>> _verify_type((1, None), schema) # doctest: +IGNORE_EXCEPTION_DETAIL
+ >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
"""
- if obj is None:
- if nullable:
- return
- else:
- raise ValueError("This field is not nullable, but got None")
- # StringType can work with any types
- if isinstance(dataType, StringType):
- return
+ if name is None:
+ new_msg = lambda msg: msg
+ new_name = lambda n: "field %s" % n
+ else:
+ new_msg = lambda msg: "%s: %s" % (name, msg)
+ new_name = lambda n: "field %s in %s" % (n, name)
- if isinstance(dataType, UserDefinedType):
- if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
- raise ValueError("%r is not an instance of type %r" % (obj, dataType))
- _verify_type(dataType.toInternal(obj), dataType.sqlType())
- return
+ def verify_nullability(obj):
+ if obj is None:
+ if nullable:
+ return True
+ else:
+ raise ValueError(new_msg("This field is not nullable, but got None"))
+ else:
+ return False
_type = type(dataType)
- assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj)
- if _type is StructType:
- # check the type and fields later
- pass
- else:
+ def assert_acceptable_types(obj):
+ assert _type in _acceptable_types, \
+ new_msg("unknown datatype: %s for object %r" % (dataType, obj))
+
+ def verify_acceptable_types(obj):
# subclass of them can not be fromInternal in JVM
if type(obj) not in _acceptable_types[_type]:
- raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj)))
+ raise TypeError(new_msg("%s can not accept object %r in type %s"
+ % (dataType, obj, type(obj))))
+
+ if isinstance(dataType, StringType):
+ # StringType can work with any types
+ verify_value = lambda _: _
+
+ elif isinstance(dataType, UserDefinedType):
+ verifier = _make_type_verifier(dataType.sqlType(), name=name)
- if isinstance(dataType, ByteType):
- if obj < -128 or obj > 127:
- raise ValueError("object of ByteType out of range, got: %s" % obj)
+ def verify_udf(obj):
+ if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
+ raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
+ verifier(dataType.toInternal(obj))
+
+ verify_value = verify_udf
+
+ elif isinstance(dataType, ByteType):
+ def verify_byte(obj):
+ assert_acceptable_types(obj)
+ verify_acceptable_types(obj)
+ if obj < -128 or obj > 127:
+ raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))
+
+ verify_value = verify_byte
elif isinstance(dataType, ShortType):
- if obj < -32768 or obj > 32767:
- raise ValueError("object of ShortType out of range, got: %s" % obj)
+ def verify_short(obj):
+ assert_acceptable_types(obj)
+ verify_acceptable_types(obj)
+ if obj < -32768 or obj > 32767:
+ raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))
+
+ verify_value = verify_short
elif isinstance(dataType, IntegerType):
- if obj < -2147483648 or obj > 2147483647:
- raise ValueError("object of IntegerType out of range, got: %s" % obj)
+ def verify_integer(obj):
+ assert_acceptable_types(obj)
+ verify_acceptable_types(obj)
+ if obj < -2147483648 or obj > 2147483647:
+ raise ValueError(
+ new_msg("object of IntegerType out of range, got: %s" % obj))
+
+ verify_value = verify_integer
elif isinstance(dataType, ArrayType):
- for i in obj:
- _verify_type(i, dataType.elementType, dataType.containsNull)
+ element_verifier = _make_type_verifier(
+ dataType.elementType, dataType.containsNull, name="element in array %s" % name)
+
+ def verify_array(obj):
+ assert_acceptable_types(obj)
+ verify_acceptable_types(obj)
+ for i in obj:
+ element_verifier(i)
+
+ verify_value = verify_array
elif isinstance(dataType, MapType):
- for k, v in obj.items():
- _verify_type(k, dataType.keyType, False)
- _verify_type(v, dataType.valueType, dataType.valueContainsNull)
+ key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name)
+ value_verifier = _make_type_verifier(
+ dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name)
+
+ def verify_map(obj):
+ assert_acceptable_types(obj)
+ verify_acceptable_types(obj)
+ for k, v in obj.items():
+ key_verifier(k)
+ value_verifier(v)
+
+ verify_value = verify_map
elif isinstance(dataType, StructType):
- if isinstance(obj, dict):
- for f in dataType.fields:
- _verify_type(obj.get(f.name), f.dataType, f.nullable)
- elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
- # the order in obj could be different than dataType.fields
- for f in dataType.fields:
- _verify_type(obj[f.name], f.dataType, f.nullable)
- elif isinstance(obj, (tuple, list)):
- if len(obj) != len(dataType.fields):
- raise ValueError("Length of object (%d) does not match with "
- "length of fields (%d)" % (len(obj), len(dataType.fields)))
- for v, f in zip(obj, dataType.fields):
- _verify_type(v, f.dataType, f.nullable)
- elif hasattr(obj, "__dict__"):
- d = obj.__dict__
- for f in dataType.fields:
- _verify_type(d.get(f.name), f.dataType, f.nullable)
- else:
- raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj)))
+ verifiers = []
+ for f in dataType.fields:
+ verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name))
+ verifiers.append((f.name, verifier))
+
+ def verify_struct(obj):
+ assert_acceptable_types(obj)
+
+ if isinstance(obj, dict):
+ for f, verifier in verifiers:
+ verifier(obj.get(f))
+ elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+ # the order in obj could be different than dataType.fields
+ for f, verifier in verifiers:
+ verifier(obj[f])
+ elif isinstance(obj, (tuple, list)):
+ if len(obj) != len(verifiers):
+ raise ValueError(
+ new_msg("Length of object (%d) does not match with "
+ "length of fields (%d)" % (len(obj), len(verifiers))))
+ for v, (_, verifier) in zip(obj, verifiers):
+ verifier(v)
+ elif hasattr(obj, "__dict__"):
+ d = obj.__dict__
+ for f, verifier in verifiers:
+ verifier(d.get(f))
+ else:
+ raise TypeError(new_msg("StructType can not accept object %r in type %s"
+ % (obj, type(obj))))
+ verify_value = verify_struct
+
+ else:
+ def verify_default(obj):
+ assert_acceptable_types(obj)
+ verify_acceptable_types(obj)
+
+ verify_value = verify_default
+
+ def verify(obj):
+ if not verify_nullability(obj):
+ verify_value(obj)
+
+ return verify
# This is used to unpickle a Row from JVM
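
The refactor above replaces the per-value `_verify_type` recursion with a factory: `_make_type_verifier` builds the nested field verifiers once per schema and returns a closure that is reused for every row, threading a `name` through so errors point at the offending field. A minimal usage sketch; the schema and field names are illustrative only:

from pyspark.sql.types import (_make_type_verifier, IntegerType, StringType,
                               StructField, StructType)

schema = StructType([StructField("name", StringType(), False),
                     StructField("age", IntegerType(), True)])

# Build the verifier once; per-field closures are created up front.
verify = _make_type_verifier(schema, nullable=False, name="person")

verify(("alice", 23))             # passes
try:
    verify(("bob", "not-an-int"))
except TypeError as e:
    print(e)                      # error message names the failing field
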
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index ef16eaa30afa..11526824da8c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1019,14 +1019,22 @@ def test_histogram(self):
self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1))
self.assertRaises(TypeError, lambda: rdd.histogram(2))
- def test_repartitionAndSortWithinPartitions(self):
+ def test_repartitionAndSortWithinPartitions_asc(self):
rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
- repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
+ repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True)
partitions = repartitioned.glom().collect()
self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)])
self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)])
+ def test_repartitionAndSortWithinPartitions_desc(self):
+ rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
+
+ repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False)
+ partitions = repartitioned.glom().collect()
+ self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)])
+ self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)])
+
def test_repartition_no_skewed(self):
num_partitions = 20
a = self.sc.parallelize(range(int(1000)), 2)
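
The new descending test pins down the behaviour of the `ascending` argument of `repartitionAndSortWithinPartitions`. Outside the test harness the same call might look like the sketch below, assuming a local SparkContext:

from pyspark import SparkContext

sc = SparkContext("local[2]", "sort-within-partitions-demo")
rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)

# Partition by key parity, then sort keys in descending order within each partition.
result = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False)
print(result.glom().collect())
# expected, per the test above:
# [[(2, 6), (0, 5), (0, 8)], [(3, 8), (3, 8), (1, 3)]]
sc.stop()
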
diff --git a/python/run-tests.py b/python/run-tests.py
index b2e50435bb19..afd3d29a0ff9 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -54,7 +54,8 @@ def print_red(text):
LOGGER = logging.getLogger()
# Find out where the assembly jars are located.
-for scala in ["2.11", "2.10"]:
+# Later, add back 2.12 to this list:
+for scala in ["2.11"]:
build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala)
if os.path.isdir(build_dir):
SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*")
diff --git a/python/setup.py b/python/setup.py
index 2644d3e79dea..cfc83c68e3df 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -194,7 +194,7 @@ def _supports_symlinks():
'pyspark.examples.src.main.python': ['*.py', '*/*.py']},
scripts=scripts,
license='http://www.apache.org/licenses/LICENSE-2.0',
- install_requires=['py4j==0.10.4'],
+ install_requires=['py4j==0.10.6'],
setup_requires=['pypandoc'],
extras_require={
'ml': ['numpy>=1.7'],
diff --git a/repl/pom.xml b/repl/pom.xml
index 6d133a3cfff7..51eb9b60dd54 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -32,8 +32,8 @@
repl
- scala-2.10/src/main/scala
- scala-2.10/src/test/scala
+ scala-2.11/src/main/scala
+ scala-2.11/src/test/scala
@@ -71,7 +71,7 @@
${scala.version}
- ${jline.groupid}
+ jline
jline
@@ -170,23 +170,17 @@
+
+
+
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
deleted file mode 100644
index be9b79021d2a..000000000000
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.repl
-
-import scala.tools.nsc.{CompilerCommand, Settings}
-
-import org.apache.spark.annotation.DeveloperApi
-
-/**
- * Command class enabling Spark-specific command line options (provided by
- * org.apache.spark.repl.SparkRunnerSettings).
- *
- * @example new SparkCommandLine(Nil).settings
- *
- * @param args The list of command line arguments
- * @param settings The underlying settings to associate with this set of
- * command-line options
- */
-@DeveloperApi
-class SparkCommandLine(args: List[String], override val settings: Settings)
- extends CompilerCommand(args, settings) {
- def this(args: List[String], error: String => Unit) {
- this(args, new SparkRunnerSettings(error))
- }
-
- def this(args: List[String]) {
- // scalastyle:off println
- this(args, str => Console.println("Error: " + str))
- // scalastyle:on println
- }
-}
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
deleted file mode 100644
index 2b5d56a89590..000000000000
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
+++ /dev/null
@@ -1,114 +0,0 @@
-// scalastyle:off
-
-/* NSC -- new Scala compiler
- * Copyright 2005-2013 LAMP/EPFL
- * @author Paul Phillips
- */
-
-package org.apache.spark.repl
-
-import scala.tools.nsc._
-import scala.tools.nsc.interpreter._
-
-import scala.reflect.internal.util.BatchSourceFile
-import scala.tools.nsc.ast.parser.Tokens.EOF
-
-import org.apache.spark.internal.Logging
-
-private[repl] trait SparkExprTyper extends Logging {
- val repl: SparkIMain
-
- import repl._
- import global.{ reporter => _, Import => _, _ }
- import definitions._
- import syntaxAnalyzer.{ UnitParser, UnitScanner, token2name }
- import naming.freshInternalVarName
-
- object codeParser extends { val global: repl.global.type = repl.global } with CodeHandlers[Tree] {
- def applyRule[T](code: String, rule: UnitParser => T): T = {
- reporter.reset()
- val scanner = newUnitParser(code)
- val result = rule(scanner)
-
- if (!reporter.hasErrors)
- scanner.accept(EOF)
-
- result
- }
-
- def defns(code: String) = stmts(code) collect { case x: DefTree => x }
- def expr(code: String) = applyRule(code, _.expr())
- def stmts(code: String) = applyRule(code, _.templateStats())
- def stmt(code: String) = stmts(code).last // guaranteed nonempty
- }
-
- /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */
- def parse(line: String): Option[List[Tree]] = debugging(s"""parse("$line")""") {
- var isIncomplete = false
- reporter.withIncompleteHandler((_, _) => isIncomplete = true) {
- val trees = codeParser.stmts(line)
- if (reporter.hasErrors) {
- Some(Nil)
- } else if (isIncomplete) {
- None
- } else {
- Some(trees)
- }
- }
- }
- // def parsesAsExpr(line: String) = {
- // import codeParser._
- // (opt expr line).isDefined
- // }
-
- def symbolOfLine(code: String): Symbol = {
- def asExpr(): Symbol = {
- val name = freshInternalVarName()
- // Typing it with a lazy val would give us the right type, but runs
- // into compiler bugs with things like existentials, so we compile it
- // behind a def and strip the NullaryMethodType which wraps the expr.
- val line = "def " + name + " = {\n" + code + "\n}"
-
- interpretSynthetic(line) match {
- case IR.Success =>
- val sym0 = symbolOfTerm(name)
- // drop NullaryMethodType
- val sym = sym0.cloneSymbol setInfo afterTyper(sym0.info.finalResultType)
- if (sym.info.typeSymbol eq UnitClass) NoSymbol else sym
- case _ => NoSymbol
- }
- }
- def asDefn(): Symbol = {
- val old = repl.definedSymbolList.toSet
-
- interpretSynthetic(code) match {
- case IR.Success =>
- repl.definedSymbolList filterNot old match {
- case Nil => NoSymbol
- case sym :: Nil => sym
- case syms => NoSymbol.newOverloaded(NoPrefix, syms)
- }
- case _ => NoSymbol
- }
- }
- beQuietDuring(asExpr()) orElse beQuietDuring(asDefn())
- }
-
- private var typeOfExpressionDepth = 0
- def typeOfExpression(expr: String, silent: Boolean = true): Type = {
- if (typeOfExpressionDepth > 2) {
- logDebug("Terminating typeOfExpression recursion for expression: " + expr)
- return NoType
- }
- typeOfExpressionDepth += 1
- // Don't presently have a good way to suppress undesirable success output
- // while letting errors through, so it is first trying it silently: if there
- // is an error, and errors are desired, then it re-evaluates non-silently
- // to induce the error message.
- try beSilentDuring(symbolOfLine(expr).tpe) match {
- case NoType if !silent => symbolOfLine(expr).tpe // generate error
- case tpe => tpe
- }
- finally typeOfExpressionDepth -= 1
- }
-}
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
deleted file mode 100644
index b7237a6ce822..000000000000
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ /dev/null
@@ -1,1145 +0,0 @@
-// scalastyle:off
-
-/* NSC -- new Scala compiler
- * Copyright 2005-2013 LAMP/EPFL
- * @author Alexander Spoon
- */
-
-package org.apache.spark.repl
-
-
-import java.net.URL
-
-import scala.reflect.io.AbstractFile
-import scala.tools.nsc._
-import scala.tools.nsc.backend.JavaPlatform
-import scala.tools.nsc.interpreter._
-import scala.tools.nsc.interpreter.{Results => IR}
-import Predef.{println => _, _}
-import java.io.{BufferedReader, FileReader}
-import java.net.URI
-import java.util.concurrent.locks.ReentrantLock
-import scala.sys.process.Process
-import scala.tools.nsc.interpreter.session._
-import scala.util.Properties.{jdkHome, javaVersion}
-import scala.tools.util.{Javap}
-import scala.annotation.tailrec
-import scala.collection.mutable.ListBuffer
-import scala.concurrent.ops
-import scala.tools.nsc.util._
-import scala.tools.nsc.interpreter._
-import scala.tools.nsc.io.{File, Directory}
-import scala.reflect.NameTransformer._
-import scala.tools.nsc.util.ScalaClassLoader._
-import scala.tools.util._
-import scala.language.{implicitConversions, existentials, postfixOps}
-import scala.reflect.{ClassTag, classTag}
-import scala.tools.reflect.StdRuntimeTags._
-
-import java.lang.{Class => jClass}
-import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.SparkContext
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.util.Utils
-
-/** The Scala interactive shell. It provides a read-eval-print loop
- * around the Interpreter class.
- * After instantiation, clients should call the main() method.
- *
- * If no in0 is specified, then input will come from the console, and
- * the class will attempt to provide input editing feature such as
- * input history.
- *
- * @author Moez A. Abdel-Gawad
- * @author Lex Spoon
- * @version 1.2
- */
-@DeveloperApi
-class SparkILoop(
- private val in0: Option[BufferedReader],
- protected val out: JPrintWriter,
- val master: Option[String]
-) extends AnyRef with LoopCommands with SparkILoopInit with Logging {
- def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master))
- def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None)
- def this() = this(None, new JPrintWriter(Console.out, true), None)
-
- private var in: InteractiveReader = _ // the input stream from which commands come
-
- // NOTE: Exposed in package for testing
- private[repl] var settings: Settings = _
-
- private[repl] var intp: SparkIMain = _
-
- @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp
- @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i
-
- /** Having inherited the difficult "var-ness" of the repl instance,
- * I'm trying to work around it by moving operations into a class from
- * which it will appear a stable prefix.
- */
- private def onIntp[T](f: SparkIMain => T): T = f(intp)
-
- class IMainOps[T <: SparkIMain](val intp: T) {
- import intp._
- import global._
-
- def printAfterTyper(msg: => String) =
- intp.reporter printMessage afterTyper(msg)
-
- /** Strip NullaryMethodType artifacts. */
- private def replInfo(sym: Symbol) = {
- sym.info match {
- case NullaryMethodType(restpe) if sym.isAccessor => restpe
- case info => info
- }
- }
- def echoTypeStructure(sym: Symbol) =
- printAfterTyper("" + deconstruct.show(replInfo(sym)))
-
- def echoTypeSignature(sym: Symbol, verbose: Boolean) = {
- if (verbose) SparkILoop.this.echo("// Type signature")
- printAfterTyper("" + replInfo(sym))
-
- if (verbose) {
- SparkILoop.this.echo("\n// Internal Type structure")
- echoTypeStructure(sym)
- }
- }
- }
- implicit def stabilizeIMain(intp: SparkIMain) = new IMainOps[intp.type](intp)
-
- /** TODO -
- * -n normalize
- * -l label with case class parameter names
- * -c complete - leave nothing out
- */
- private def typeCommandInternal(expr: String, verbose: Boolean): Result = {
- onIntp { intp =>
- val sym = intp.symbolOfLine(expr)
- if (sym.exists) intp.echoTypeSignature(sym, verbose)
- else ""
- }
- }
-
- // NOTE: Must be public for visibility
- @DeveloperApi
- var sparkContext: SparkContext = _
-
- override def echoCommandMessage(msg: String) {
- intp.reporter printMessage msg
- }
-
- // def isAsync = !settings.Yreplsync.value
- private[repl] def isAsync = false
- // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals])
- private def history = in.history
-
- /** The context class loader at the time this object was created */
- protected val originalClassLoader = Utils.getContextOrSparkClassLoader
-
- // classpath entries added via :cp
- private var addedClasspath: String = ""
-
- /** A reverse list of commands to replay if the user requests a :replay */
- private var replayCommandStack: List[String] = Nil
-
- /** A list of commands to replay if the user requests a :replay */
- private def replayCommands = replayCommandStack.reverse
-
- /** Record a command for replay should the user request a :replay */
- private def addReplay(cmd: String) = replayCommandStack ::= cmd
-
- private def savingReplayStack[T](body: => T): T = {
- val saved = replayCommandStack
- try body
- finally replayCommandStack = saved
- }
- private def savingReader[T](body: => T): T = {
- val saved = in
- try body
- finally in = saved
- }
-
-
- private def sparkCleanUp() {
- echo("Stopping spark context.")
- intp.beQuietDuring {
- command("sc.stop()")
- }
- }
- /** Close the interpreter and set the var to null. */
- private def closeInterpreter() {
- if (intp ne null) {
- sparkCleanUp()
- intp.close()
- intp = null
- }
- }
-
- class SparkILoopInterpreter extends SparkIMain(settings, out) {
- outer =>
-
- override private[repl] lazy val formatting = new Formatting {
- def prompt = SparkILoop.this.prompt
- }
- override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
- }
-
- /**
- * Constructs a new interpreter.
- */
- protected def createInterpreter() {
- require(settings != null)
-
- if (addedClasspath != "") settings.classpath.append(addedClasspath)
- val addedJars =
- if (Utils.isWindows) {
- // Strip any URI scheme prefix so we can add the correct path to the classpath
- // e.g. file:/C:/my/path.jar -> C:/my/path.jar
- getAddedJars().map { jar => new URI(jar).getPath.stripPrefix("/") }
- } else {
- // We need new URI(jar).getPath here for the case that `jar` includes encoded white space (%20).
- getAddedJars().map { jar => new URI(jar).getPath }
- }
- // work around for Scala bug
- val totalClassPath = addedJars.foldLeft(
- settings.classpath.value)((l, r) => ClassPath.join(l, r))
- this.settings.classpath.value = totalClassPath
-
- intp = new SparkILoopInterpreter
- }
-
- /** print a friendly help message */
- private def helpCommand(line: String): Result = {
- if (line == "") helpSummary()
- else uniqueCommand(line) match {
- case Some(lc) => echo("\n" + lc.longHelp)
- case _ => ambiguousError(line)
- }
- }
- private def helpSummary() = {
- val usageWidth = commands map (_.usageMsg.length) max
- val formatStr = "%-" + usageWidth + "s %s %s"
-
- echo("All commands can be abbreviated, e.g. :he instead of :help.")
- echo("Those marked with a * have more detailed help, e.g. :help imports.\n")
-
- commands foreach { cmd =>
- val star = if (cmd.hasLongHelp) "*" else " "
- echo(formatStr.format(cmd.usageMsg, star, cmd.help))
- }
- }
- private def ambiguousError(cmd: String): Result = {
- matchingCommands(cmd) match {
- case Nil => echo(cmd + ": no such command. Type :help for help.")
- case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?")
- }
- Result(true, None)
- }
- private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd)
- private def uniqueCommand(cmd: String): Option[LoopCommand] = {
- // this lets us add commands willy-nilly and only requires enough command to disambiguate
- matchingCommands(cmd) match {
- case List(x) => Some(x)
- // exact match OK even if otherwise appears ambiguous
- case xs => xs find (_.name == cmd)
- }
- }
- private var fallbackMode = false
-
- private def toggleFallbackMode() {
- val old = fallbackMode
- fallbackMode = !old
- System.setProperty("spark.repl.fallback", fallbackMode.toString)
- echo(s"""
- |Switched ${if (old) "off" else "on"} fallback mode without restarting.
- | If you have defined classes in the repl, it would
- |be good to redefine them incase you plan to use them. If you still run
- |into issues it would be good to restart the repl and turn on `:fallback`
- |mode as first command.
- """.stripMargin)
- }
-
- /** Show the history */
- private lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
- override def usage = "[num]"
- def defaultLines = 20
-
- def apply(line: String): Result = {
- if (history eq NoHistory)
- return "No history available."
-
- val xs = words(line)
- val current = history.index
- val count = try xs.head.toInt catch { case _: Exception => defaultLines }
- val lines = history.asStrings takeRight count
- val offset = current - lines.size + 1
-
- for ((line, index) <- lines.zipWithIndex)
- echo("%3d %s".format(index + offset, line))
- }
- }
-
- // When you know you are most likely breaking into the middle
- // of a line being typed. This softens the blow.
- private[repl] def echoAndRefresh(msg: String) = {
- echo("\n" + msg)
- in.redrawLine()
- }
- private[repl] def echo(msg: String) = {
- out println msg
- out.flush()
- }
- private def echoNoNL(msg: String) = {
- out print msg
- out.flush()
- }
-
- /** Search the history */
- private def searchHistory(_cmdline: String) {
- val cmdline = _cmdline.toLowerCase
- val offset = history.index - history.size + 1
-
- for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline)
- echo("%d %s".format(index + offset, line))
- }
-
- private var currentPrompt = Properties.shellPromptString
-
- /**
- * Sets the prompt string used by the REPL.
- *
- * @param prompt The new prompt string
- */
- @DeveloperApi
- def setPrompt(prompt: String) = currentPrompt = prompt
-
- /**
- * Represents the current prompt string used by the REPL.
- *
- * @return The current prompt string
- */
- @DeveloperApi
- def prompt = currentPrompt
-
- import LoopCommand.{ cmd, nullary }
-
- /** Standard commands */
- private lazy val standardCommands = List(
- cmd("cp", "", "add a jar or directory to the classpath", addClasspath),
- cmd("help", "[command]", "print this summary or command-specific help", helpCommand),
- historyCommand,
- cmd("h?", "", "search the history", searchHistory),
- cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand),
- cmd("implicits", "[-v]", "show the implicits in scope", implicitsCommand),
- cmd("javap", "", "disassemble a file or class name", javapCommand),
- cmd("load", "", "load and interpret a Scala file", loadCommand),
- nullary("paste", "enter paste mode: all input up to ctrl-D compiled together", pasteCommand),
-// nullary("power", "enable power user mode", powerCmd),
- nullary("quit", "exit the repl", () => Result(false, None)),
- nullary("replay", "reset execution and replay all previous commands", replay),
- nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand),
- shCommand,
- nullary("silent", "disable/enable automatic printing of results", verbosity),
- nullary("fallback", """
- |disable/enable advanced repl changes, these fix some issues but may introduce others.
- |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode),
- cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand),
- nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand)
- )
-
- /** Power user commands */
- private lazy val powerCommands: List[LoopCommand] = List(
- // cmd("phase", "", "set the implicit phase for power commands", phaseCommand)
- )
-
- // private def dumpCommand(): Result = {
- // echo("" + power)
- // history.asStrings takeRight 30 foreach echo
- // in.redrawLine()
- // }
- // private def valsCommand(): Result = power.valsDescription
-
- private val typeTransforms = List(
- "scala.collection.immutable." -> "immutable.",
- "scala.collection.mutable." -> "mutable.",
- "scala.collection.generic." -> "generic.",
- "java.lang." -> "jl.",
- "scala.runtime." -> "runtime."
- )
-
- private def importsCommand(line: String): Result = {
- val tokens = words(line)
- val handlers = intp.languageWildcardHandlers ++ intp.importHandlers
- val isVerbose = tokens contains "-v"
-
- handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach {
- case (handler, idx) =>
- val (types, terms) = handler.importedSymbols partition (_.name.isTypeName)
- val imps = handler.implicitSymbols
- val found = tokens filter (handler importsSymbolNamed _)
- val typeMsg = if (types.isEmpty) "" else types.size + " types"
- val termMsg = if (terms.isEmpty) "" else terms.size + " terms"
- val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit"
- val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "")
- val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")")
-
- intp.reporter.printMessage("%2d) %-30s %s%s".format(
- idx + 1,
- handler.importString,
- statsMsg,
- foundMsg
- ))
- }
- }
-
- private def implicitsCommand(line: String): Result = onIntp { intp =>
- import intp._
- import global._
-
- def p(x: Any) = intp.reporter.printMessage("" + x)
-
- // If an argument is given, only show a source with that
- // in its name somewhere.
- val args = line split "\\s+"
- val filtered = intp.implicitSymbolsBySource filter {
- case (source, syms) =>
- (args contains "-v") || {
- if (line == "") (source.fullName.toString != "scala.Predef")
- else (args exists (source.name.toString contains _))
- }
- }
-
- if (filtered.isEmpty)
- return "No implicits have been imported other than those in Predef."
-
- filtered foreach {
- case (source, syms) =>
- p("/* " + syms.size + " implicit members imported from " + source.fullName + " */")
-
- // This groups the members by where the symbol is defined
- val byOwner = syms groupBy (_.owner)
- val sortedOwners = byOwner.toList sortBy { case (owner, _) => afterTyper(source.info.baseClasses indexOf owner) }
-
- sortedOwners foreach {
- case (owner, members) =>
- // Within each owner, we cluster results based on the final result type
- // if there are more than a couple, and sort each cluster based on name.
- // This is really just trying to make the 100 or so implicits imported
- // by default into something readable.
- val memberGroups: List[List[Symbol]] = {
- val groups = members groupBy (_.tpe.finalResultType) toList
- val (big, small) = groups partition (_._2.size > 3)
- val xss = (
- (big sortBy (_._1.toString) map (_._2)) :+
- (small flatMap (_._2))
- )
-
- xss map (xs => xs sortBy (_.name.toString))
- }
-
- val ownerMessage = if (owner == source) " defined in " else " inherited from "
- p(" /* " + members.size + ownerMessage + owner.fullName + " */")
-
- memberGroups foreach { group =>
- group foreach (s => p(" " + intp.symbolDefString(s)))
- p("")
- }
- }
- p("")
- }
- }
-
- private def findToolsJar() = {
- val jdkPath = Directory(jdkHome)
- val jar = jdkPath / "lib" / "tools.jar" toFile;
-
- if (jar isFile)
- Some(jar)
- else if (jdkPath.isDirectory)
- jdkPath.deepFiles find (_.name == "tools.jar")
- else None
- }
- private def addToolsJarToLoader() = {
- val cl = findToolsJar match {
- case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader)
- case _ => intp.classLoader
- }
- if (Javap.isAvailable(cl)) {
- logDebug(":javap available.")
- cl
- }
- else {
- logDebug(":javap unavailable: no tools.jar at " + jdkHome)
- intp.classLoader
- }
- }
-
- private def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) {
- override def tryClass(path: String): Array[Byte] = {
- val hd :: rest = path split '.' toList;
- // If there are dots in the name, the first segment is the
- // key to finding it.
- if (rest.nonEmpty) {
- intp optFlatName hd match {
- case Some(flat) =>
- val clazz = flat :: rest mkString NAME_JOIN_STRING
- val bytes = super.tryClass(clazz)
- if (bytes.nonEmpty) bytes
- else super.tryClass(clazz + MODULE_SUFFIX_STRING)
- case _ => super.tryClass(path)
- }
- }
- else {
- // Look for Foo first, then Foo$, but if Foo$ is given explicitly,
- // we have to drop the $ to find object Foo, then tack it back onto
- // the end of the flattened name.
- def className = intp flatName path
- def moduleName = (intp flatName path.stripSuffix(MODULE_SUFFIX_STRING)) + MODULE_SUFFIX_STRING
-
- val bytes = super.tryClass(className)
- if (bytes.nonEmpty) bytes
- else super.tryClass(moduleName)
- }
- }
- }
- // private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap())
- private lazy val javap =
- try newJavap()
- catch { case _: Exception => null }
-
- // Still todo: modules.
- private def typeCommand(line0: String): Result = {
- line0.trim match {
- case "" => ":type [-v] "
- case s if s startsWith "-v " => typeCommandInternal(s stripPrefix "-v " trim, true)
- case s => typeCommandInternal(s, false)
- }
- }
-
- private def warningsCommand(): Result = {
- if (intp.lastWarnings.isEmpty)
- "Can't find any cached warnings."
- else
- intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) }
- }
-
- private def javapCommand(line: String): Result = {
- if (javap == null)
- ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome)
- else if (javaVersion startsWith "1.7")
- ":javap not yet working with java 1.7"
- else if (line == "")
- ":javap [-lcsvp] [path1 path2 ...]"
- else
- javap(words(line)) foreach { res =>
- if (res.isError) return "Failed: " + res.value
- else res.show()
- }
- }
-
- private def wrapCommand(line: String): Result = {
- def failMsg = "Argument to :wrap must be the name of a method with signature [T](=> T): T"
- onIntp { intp =>
- import intp._
- import global._
-
- words(line) match {
- case Nil =>
- intp.executionWrapper match {
- case "" => "No execution wrapper is set."
- case s => "Current execution wrapper: " + s
- }
- case "clear" :: Nil =>
- intp.executionWrapper match {
- case "" => "No execution wrapper is set."
- case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper."
- }
- case wrapper :: Nil =>
- intp.typeOfExpression(wrapper) match {
- case PolyType(List(targ), MethodType(List(arg), restpe)) =>
- intp setExecutionWrapper intp.pathToTerm(wrapper)
- "Set wrapper to '" + wrapper + "'"
- case tp =>
- failMsg + "\nFound: "
- }
- case _ => failMsg
- }
- }
- }
-
- private def pathToPhaseWrapper = intp.pathToTerm("$r") + ".phased.atCurrent"
- // private def phaseCommand(name: String): Result = {
- // val phased: Phased = power.phased
- // import phased.NoPhaseName
-
- // if (name == "clear") {
- // phased.set(NoPhaseName)
- // intp.clearExecutionWrapper()
- // "Cleared active phase."
- // }
- // else if (name == "") phased.get match {
- // case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)"
- // case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get)
- // }
- // else {
- // val what = phased.parse(name)
- // if (what.isEmpty || !phased.set(what))
- // "'" + name + "' does not appear to represent a valid phase."
- // else {
- // intp.setExecutionWrapper(pathToPhaseWrapper)
- // val activeMessage =
- // if (what.toString.length == name.length) "" + what
- // else "%s (%s)".format(what, name)
-
- // "Active phase is now: " + activeMessage
- // }
- // }
- // }
-
- /**
- * Provides a list of available commands.
- *
- * @return The list of commands
- */
- @DeveloperApi
- def commands: List[LoopCommand] = standardCommands /*++ (
- if (isReplPower) powerCommands else Nil
- )*/
-
- private val replayQuestionMessage =
- """|That entry seems to have slain the compiler. Shall I replay
- |your session? I can re-run each line except the last one.
- |[y/n]
- """.trim.stripMargin
-
- private def crashRecovery(ex: Throwable): Boolean = {
- echo(ex.toString)
- ex match {
- case _: NoSuchMethodError | _: NoClassDefFoundError =>
- echo("\nUnrecoverable error.")
- throw ex
- case _ =>
- def fn(): Boolean =
- try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
- catch { case _: RuntimeException => false }
-
- if (fn()) replay()
- else echo("\nAbandoning crashed session.")
- }
- true
- }
-
- /** The main read-eval-print loop for the repl. It calls
- * command() for each line of input, and stops when
- * command() returns false.
- */
- private def loop() {
- def readOneLine() = {
- out.flush()
- in readLine prompt
- }
- // return false if repl should exit
- def processLine(line: String): Boolean = {
- if (isAsync) {
- if (!awaitInitialized()) return false
- runThunks()
- }
- if (line eq null) false // assume null means EOF
- else command(line) match {
- case Result(false, _) => false
- case Result(_, Some(finalLine)) => addReplay(finalLine) ; true
- case _ => true
- }
- }
- def innerLoop() {
- val shouldContinue = try {
- processLine(readOneLine())
- } catch {case t: Throwable => crashRecovery(t)}
- if (shouldContinue)
- innerLoop()
- }
- innerLoop()
- }
-
- /** interpret all lines from a specified file */
- private def interpretAllFrom(file: File) {
- savingReader {
- savingReplayStack {
- file applyReader { reader =>
- in = SimpleReader(reader, out, false)
- echo("Loading " + file + "...")
- loop()
- }
- }
- }
- }
-
- /** create a new interpreter and replay the given commands */
- private def replay() {
- reset()
- if (replayCommandStack.isEmpty)
- echo("Nothing to replay.")
- else for (cmd <- replayCommands) {
- echo("Replaying: " + cmd) // flush because maybe cmd will have its own output
- command(cmd)
- echo("")
- }
- }
- private def resetCommand() {
- echo("Resetting repl state.")
- if (replayCommandStack.nonEmpty) {
- echo("Forgetting this session history:\n")
- replayCommands foreach echo
- echo("")
- replayCommandStack = Nil
- }
- if (intp.namedDefinedTerms.nonEmpty)
- echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", "))
- if (intp.definedTypes.nonEmpty)
- echo("Forgetting defined types: " + intp.definedTypes.mkString(", "))
-
- reset()
- }
-
- private def reset() {
- intp.reset()
- // unleashAndSetPhase()
- }
-
- /** fork a shell and run a command */
- private lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
- override def usage = ""
- def apply(line: String): Result = line match {
- case "" => showUsage()
- case _ =>
- val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")"
- intp interpret toRun
- ()
- }
- }
-
- private def withFile(filename: String)(action: File => Unit) {
- val f = File(filename)
-
- if (f.exists) action(f)
- else echo("That file does not exist")
- }
-
- private def loadCommand(arg: String) = {
- var shouldReplay: Option[String] = None
- withFile(arg)(f => {
- interpretAllFrom(f)
- shouldReplay = Some(":load " + arg)
- })
- Result(true, shouldReplay)
- }
-
- private def addAllClasspath(args: Seq[String]): Unit = {
- var added = false
- var totalClasspath = ""
- for (arg <- args) {
- val f = File(arg).normalize
- if (f.exists) {
- added = true
- addedClasspath = ClassPath.join(addedClasspath, f.path)
- totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
- intp.addUrlsToClassPath(f.toURI.toURL)
- sparkContext.addJar(f.toURI.toURL.getPath)
- }
- }
- }
-
- private def addClasspath(arg: String): Unit = {
- val f = File(arg).normalize
- if (f.exists) {
- addedClasspath = ClassPath.join(addedClasspath, f.path)
- intp.addUrlsToClassPath(f.toURI.toURL)
- sparkContext.addJar(f.toURI.toURL.getPath)
- echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, intp.global.classPath.asClasspathString))
- }
- else echo("The path '" + f + "' doesn't seem to exist.")
- }
-
-
- private def powerCmd(): Result = {
- if (isReplPower) "Already in power mode."
- else enablePowerMode(false)
- }
-
- private[repl] def enablePowerMode(isDuringInit: Boolean) = {
- // replProps.power setValue true
- // unleashAndSetPhase()
- // asyncEcho(isDuringInit, power.banner)
- }
- // private def unleashAndSetPhase() {
-// if (isReplPower) {
-// // power.unleash()
-// // Set the phase to "typer"
-// intp beSilentDuring phaseCommand("typer")
-// }
-// }
-
- private def asyncEcho(async: Boolean, msg: => String) {
- if (async) asyncMessage(msg)
- else echo(msg)
- }
-
- private def verbosity() = {
- // val old = intp.printResults
- // intp.printResults = !old
- // echo("Switched " + (if (old) "off" else "on") + " result printing.")
- }
-
- /**
- * Run one command submitted by the user. Two values are returned:
- * (1) whether to keep running, (2) the line to record for replay,
- * if any.
- */
- private[repl] def command(line: String): Result = {
- if (line startsWith ":") {
- val cmd = line.tail takeWhile (x => !x.isWhitespace)
- uniqueCommand(cmd) match {
- case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace))
- case _ => ambiguousError(cmd)
- }
- }
- else if (intp.global == null) Result(false, None) // Notice failure to create compiler
- else Result(true, interpretStartingWith(line))
- }
-
- private def readWhile(cond: String => Boolean) = {
- Iterator continually in.readLine("") takeWhile (x => x != null && cond(x))
- }
-
- private def pasteCommand(): Result = {
- echo("// Entering paste mode (ctrl-D to finish)\n")
- val code = readWhile(_ => true) mkString "\n"
- echo("\n// Exiting paste mode, now interpreting.\n")
- intp interpret code
- ()
- }
-
- private object paste extends Pasted {
- val ContinueString = " | "
- val PromptString = "scala> "
-
- def interpret(line: String): Unit = {
- echo(line.trim)
- intp interpret line
- echo("")
- }
-
- def transcript(start: String) = {
- echo("\n// Detected repl transcript paste: ctrl-D to finish.\n")
- apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim))
- }
- }
- import paste.{ ContinueString, PromptString }
-
- /**
- * Interpret expressions starting with the first line.
- * Read lines until a complete compilation unit is available
- * or until a syntax error has been seen. If a full unit is
- * read, go ahead and interpret it. Return the full string
- * to be recorded for replay, if any.
- */
- private def interpretStartingWith(code: String): Option[String] = {
- // signal completion non-completion input has been received
- in.completion.resetVerbosity()
-
- def reallyInterpret = {
- val reallyResult = intp.interpret(code)
- (reallyResult, reallyResult match {
- case IR.Error => None
- case IR.Success => Some(code)
- case IR.Incomplete =>
- if (in.interactive && code.endsWith("\n\n")) {
- echo("You typed two blank lines. Starting a new command.")
- None
- }
- else in.readLine(ContinueString) match {
- case null =>
- // we know compilation is going to fail since we're at EOF and the
- // parser thinks the input is still incomplete, but since this is
- // a file being read non-interactively we want to fail. So we send
- // it straight to the compiler for the nice error message.
- intp.compileString(code)
- None
-
- case line => interpretStartingWith(code + "\n" + line)
- }
- })
- }
-
- /** Here we place ourselves between the user and the interpreter and examine
- * the input they are ostensibly submitting. We intervene in several cases:
- *
- * 1) If the line starts with "scala> " it is assumed to be an interpreter paste.
- * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
- * on the previous result.
- * 3) If the Completion object's execute returns Some(_), we inject that value
- * and avoid the interpreter, as it's likely not valid scala code.
- */
- if (code == "") None
- else if (!paste.running && code.trim.startsWith(PromptString)) {
- paste.transcript(code)
- None
- }
- else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") {
- interpretStartingWith(intp.mostRecentVar + code)
- }
- else if (code.trim startsWith "//") {
- // line comment, do nothing
- None
- }
- else
- reallyInterpret._2
- }
-
- // runs :load `file` on any files passed via -i
- private def loadFiles(settings: Settings) = settings match {
- case settings: SparkRunnerSettings =>
- for (filename <- settings.loadfiles.value) {
- val cmd = ":load " + filename
- command(cmd)
- addReplay(cmd)
- echo("")
- }
- case _ =>
- }
-
- /** Tries to create a JLineReader, falling back to SimpleReader:
- * unless settings or properties are such that it should start
- * with SimpleReader.
- */
- private def chooseReader(settings: Settings): InteractiveReader = {
- if (settings.Xnojline.value || Properties.isEmacsShell)
- SimpleReader()
- else try new SparkJLineReader(
- if (settings.noCompletion.value) NoCompletion
- else new SparkJLineCompletion(intp)
- )
- catch {
- case ex @ (_: Exception | _: NoClassDefFoundError) =>
- echo("Failed to created SparkJLineReader: " + ex + "\nFalling back to SimpleReader.")
- SimpleReader()
- }
- }
-
- private val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
- private val m = u.runtimeMirror(Utils.getSparkClassLoader)
- private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
- u.TypeTag[T](
- m,
- new TypeCreator {
- def apply[U <: ApiUniverse with Singleton](m: Mirror[U]): U # Type =
- m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type]
- })
-
- private def process(settings: Settings): Boolean = savingContextLoader {
- this.settings = settings
- createInterpreter()
-
- // sets in to some kind of reader depending on environmental cues
- in = in0 match {
- case Some(reader) => SimpleReader(reader, out, true)
- case None =>
- // some post-initialization
- chooseReader(settings) match {
- case x: SparkJLineReader => addThunk(x.consoleReader.postInit) ; x
- case x => x
- }
- }
- lazy val tagOfSparkIMain = tagOfStaticClass[org.apache.spark.repl.SparkIMain]
- // Bind intp somewhere out of the regular namespace where
- // we can get at it in generated code.
- addThunk(intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfSparkIMain, classTag[SparkIMain])))
- addThunk({
- import scala.tools.nsc.io._
- import Properties.userHome
- import scala.compat.Platform.EOL
- val autorun = replProps.replAutorunCode.option flatMap (f => io.File(f).safeSlurp())
- if (autorun.isDefined) intp.quietRun(autorun.get)
- })
-
- addThunk(printWelcome())
- addThunk(initializeSpark())
-
- // it is broken on startup; go ahead and exit
- if (intp.reporter.hasErrors)
- return false
-
- // This is about the illusion of snappiness. We call initialize()
- // which spins off a separate thread, then print the prompt and try
- // our best to look ready. The interlocking lazy vals tend to
- // inter-deadlock, so we break the cycle with a single asynchronous
- // message to an rpcEndpoint.
- if (isAsync) {
- intp initialize initializedCallback()
- createAsyncListener() // listens for signal to run postInitialization
- }
- else {
- intp.initializeSynchronous()
- postInitialization()
- }
- // printWelcome()
-
- loadFiles(settings)
-
- try loop()
- catch AbstractOrMissingHandler()
- finally closeInterpreter()
-
- true
- }
-
- // NOTE: Must be public for visibility
- @DeveloperApi
- def createSparkSession(): SparkSession = {
- val execUri = System.getenv("SPARK_EXECUTOR_URI")
- val jars = getAddedJars()
- val conf = new SparkConf()
- .setMaster(getMaster())
- .setJars(jars)
- .setIfMissing("spark.app.name", "Spark shell")
- // SparkContext will detect this configuration and register it with the RpcEnv's
- // file server, setting spark.repl.class.uri to the actual URI for executors to
- // use. This is sort of ugly but since executors are started as part of SparkContext
- // initialization in certain cases, there's an initialization order issue that prevents
- // this from being set after SparkContext is instantiated.
- .set("spark.repl.class.outputDir", intp.outputDir.getAbsolutePath())
- if (execUri != null) {
- conf.set("spark.executor.uri", execUri)
- }
-
- val builder = SparkSession.builder.config(conf)
- val sparkSession = if (SparkSession.hiveClassesArePresent) {
- logInfo("Creating Spark session with Hive support")
- builder.enableHiveSupport().getOrCreate()
- } else {
- logInfo("Creating Spark session")
- builder.getOrCreate()
- }
- sparkContext = sparkSession.sparkContext
- sparkSession
- }
-
- private def getMaster(): String = {
- val master = this.master match {
- case Some(m) => m
- case None =>
- val envMaster = sys.env.get("MASTER")
- val propMaster = sys.props.get("spark.master")
- propMaster.orElse(envMaster).getOrElse("local[*]")
- }
- master
- }
-
- /** process command-line arguments and do as they request */
- def process(args: Array[String]): Boolean = {
- val command = new SparkCommandLine(args.toList, msg => echo(msg))
- def neededHelp(): String =
- (if (command.settings.help.value) command.usageMsg + "\n" else "") +
- (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "")
-
- // if they asked for no help and command is valid, we call the real main
- neededHelp() match {
- case "" => command.ok && process(command.settings)
- case help => echoNoNL(help) ; true
- }
- }
-
- @deprecated("Use `process` instead", "2.9.0")
- private def main(settings: Settings): Unit = process(settings)
-
- @DeveloperApi
- def getAddedJars(): Array[String] = {
- val conf = new SparkConf().setMaster(getMaster())
- val envJars = sys.env.get("ADD_JARS")
- if (envJars.isDefined) {
- logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead")
- }
- val jars = {
- val userJars = Utils.getUserJars(conf, isShell = true)
- if (userJars.isEmpty) {
- envJars.getOrElse("")
- } else {
- userJars.mkString(",")
- }
- }
- Utils.resolveURIs(jars).split(",").filter(_.nonEmpty)
- }
-
-}
-
-object SparkILoop extends Logging {
- implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp
- private def echo(msg: String) = Console println msg
-
- // Designed primarily for use by test code: take a String with a
- // bunch of code, and prints out a transcript of what it would look
- // like if you'd just typed it into the repl.
- private[repl] def runForTranscript(code: String, settings: Settings): String = {
- import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
-
- stringFromStream { ostream =>
- Console.withOut(ostream) {
- val output = new JPrintWriter(new OutputStreamWriter(ostream), true) {
- override def write(str: String) = {
- // completely skip continuation lines
- if (str forall (ch => ch.isWhitespace || ch == '|')) ()
- // print a newline on empty scala prompts
- else if ((str contains '\n') && (str.trim == "scala> ")) super.write("\n")
- else super.write(str)
- }
- }
- val input = new BufferedReader(new StringReader(code)) {
- override def readLine(): String = {
- val s = super.readLine()
- // helping out by printing the line being interpreted.
- if (s != null)
- // scalastyle:off println
- output.println(s)
- // scalastyle:on println
- s
- }
- }
- val repl = new SparkILoop(input, output)
-
- if (settings.classpath.isDefault)
- settings.classpath.value = sys.props("java.class.path")
-
- repl.getAddedJars().map(jar => new URI(jar).getPath).foreach(settings.classpath.append(_))
-
- repl process settings
- }
- }
- }
-
- /** Creates an interpreter loop with default settings and feeds
- * the given code to it as input.
- */
- private[repl] def run(code: String, sets: Settings = new Settings): String = {
- import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
-
- stringFromStream { ostream =>
- Console.withOut(ostream) {
- val input = new BufferedReader(new StringReader(code))
- val output = new JPrintWriter(new OutputStreamWriter(ostream), true)
- val repl = new ILoop(input, output)
-
- if (sets.classpath.isDefault)
- sets.classpath.value = sys.props("java.class.path")
-
- repl process sets
- }
- }
- }
- private[repl] def run(lines: List[String]): String = run(lines map (_ + "\n") mkString)
-}
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
deleted file mode 100644
index 5f0d92bccd80..000000000000
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
+++ /dev/null
@@ -1,168 +0,0 @@
-// scalastyle:off
-
-/* NSC -- new Scala compiler
- * Copyright 2005-2013 LAMP/EPFL
- * @author Paul Phillips
- */
-
-package org.apache.spark.repl
-
-import scala.tools.nsc._
-import scala.tools.nsc.interpreter._
-
-import scala.tools.nsc.util.stackTraceString
-
-import org.apache.spark.SPARK_VERSION
-
-/**
- * Machinery for the asynchronous initialization of the repl.
- */
-private[repl] trait SparkILoopInit {
- self: SparkILoop =>
-
- /** Print a welcome message */
- def printWelcome() {
- echo("""Welcome to
-      ____              __
-     / __/__  ___ _____/ /__
-    _\ \/ _ \/ _ `/ __/ '_/
-   /___/ .__/\_,_/_/ /_/\_\   version %s
-      /_/
-""".format(SPARK_VERSION))
- import Properties._
- val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
- versionString, javaVmName, javaVersion)
- echo(welcomeMsg)
- echo("Type in expressions to have them evaluated.")
- echo("Type :help for more information.")
- }
-
- protected def asyncMessage(msg: String) {
- if (isReplInfo || isReplPower)
- echoAndRefresh(msg)
- }
-
- private val initLock = new java.util.concurrent.locks.ReentrantLock()
- private val initCompilerCondition = initLock.newCondition() // signal the compiler is initialized
- private val initLoopCondition = initLock.newCondition() // signal the whole repl is initialized
- private val initStart = System.nanoTime
-
- private def withLock[T](body: => T): T = {
- initLock.lock()
- try body
- finally initLock.unlock()
- }
- // a condition used to ensure serial access to the compiler.
- @volatile private var initIsComplete = false
- @volatile private var initError: String = null
- private def elapsed() = "%.3f".format((System.nanoTime - initStart).toDouble / 1000000000L)
-
- // the method to be called when the interpreter is initialized.
- // Very important this method does nothing synchronous (i.e. do
- // not try to use the interpreter) because until it returns, the
- // repl's lazy val `global` is still locked.
- protected def initializedCallback() = withLock(initCompilerCondition.signal())
-
- // Spins off a thread which awaits a single message once the interpreter
- // has been initialized.
- protected def createAsyncListener() = {
- io.spawn {
- withLock(initCompilerCondition.await())
- asyncMessage("[info] compiler init time: " + elapsed() + " s.")
- postInitialization()
- }
- }
-
- // called from main repl loop
- protected def awaitInitialized(): Boolean = {
- if (!initIsComplete)
- withLock { while (!initIsComplete) initLoopCondition.await() }
- if (initError != null) {
- // scalastyle:off println
- println("""
- |Failed to initialize the REPL due to an unexpected error.
- |This is a bug, please, report it along with the error diagnostics printed below.
- |%s.""".stripMargin.format(initError)
- )
- // scalastyle:on println
- false
- } else true
- }
- // private def warningsThunks = List(
- // () => intp.bind("lastWarnings", "" + typeTag[List[(Position, String)]], intp.lastWarnings _),
- // )
-
- protected def postInitThunks = List[Option[() => Unit]](
- Some(intp.setContextClassLoader _),
- if (isReplPower) Some(() => enablePowerMode(true)) else None
- ).flatten
- // ++ (
- // warningsThunks
- // )
- // called once after init condition is signalled
- protected def postInitialization() {
- try {
- postInitThunks foreach (f => addThunk(f()))
- runThunks()
- } catch {
- case ex: Throwable =>
- initError = stackTraceString(ex)
- throw ex
- } finally {
- initIsComplete = true
-
- if (isAsync) {
- asyncMessage("[info] total init time: " + elapsed() + " s.")
- withLock(initLoopCondition.signal())
- }
- }
- }
-
- def initializeSpark() {
- intp.beQuietDuring {
- command("""
- @transient val spark = org.apache.spark.repl.Main.interp.createSparkSession()
- @transient val sc = {
- val _sc = spark.sparkContext
- if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) {
- val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null)
- if (proxyUrl != null) {
- println(s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}")
- } else {
- println(s"Spark Context Web UI is available at Spark Master Public URL")
- }
- } else {
- _sc.uiWebUrl.foreach {
- webUrl => println(s"Spark context Web UI available at ${webUrl}")
- }
- }
- println("Spark context available as 'sc' " +
- s"(master = ${_sc.master}, app id = ${_sc.applicationId}).")
- println("Spark session available as 'spark'.")
- _sc
- }
- """)
- command("import org.apache.spark.SparkContext._")
- command("import spark.implicits._")
- command("import spark.sql")
- command("import org.apache.spark.sql.functions._")
- }
- }
-
- // code to be executed only after the interpreter is initialized
- // and the lazy val `global` can be accessed without risk of deadlock.
- private var pendingThunks: List[() => Unit] = Nil
- protected def addThunk(body: => Unit) = synchronized {
- pendingThunks :+= (() => body)
- }
- protected def runThunks(): Unit = synchronized {
- if (pendingThunks.nonEmpty)
- logDebug("Clearing " + pendingThunks.size + " thunks.")
-
- while (pendingThunks.nonEmpty) {
- val thunk = pendingThunks.head
- pendingThunks = pendingThunks.tail
- thunk()
- }
- }
-}
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
deleted file mode 100644
index 74a04d5a42bb..000000000000
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ /dev/null
@@ -1,1808 +0,0 @@
-// scalastyle:off
-
-/* NSC -- new Scala compiler
- * Copyright 2005-2013 LAMP/EPFL
- * @author Martin Odersky
- */
-
-package org.apache.spark.repl
-
-import java.io.File
-
-import scala.tools.nsc._
-import scala.tools.nsc.backend.JavaPlatform
-import scala.tools.nsc.interpreter._
-
-import Predef.{ println => _, _ }
-import scala.tools.nsc.util.{MergedClassPath, stringFromWriter, ScalaClassLoader, stackTraceString}
-import scala.reflect.internal.util._
-import java.net.URL
-import scala.sys.BooleanProp
-import io.{AbstractFile, PlainFile, VirtualDirectory}
-
-import reporters._
-import symtab.Flags
-import scala.reflect.internal.Names
-import scala.tools.util.PathResolver
-import ScalaClassLoader.URLClassLoader
-import scala.tools.nsc.util.Exceptional.unwrap
-import scala.collection.{ mutable, immutable }
-import scala.util.control.Exception.{ ultimately }
-import SparkIMain._
-import java.util.concurrent.Future
-import typechecker.Analyzer
-import scala.language.implicitConversions
-import scala.reflect.runtime.{ universe => ru }
-import scala.reflect.{ ClassTag, classTag }
-import scala.tools.reflect.StdRuntimeTags._
-import scala.util.control.ControlThrowable
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
-import org.apache.spark.annotation.DeveloperApi
-
-// /** directory to save .class files to */
-// private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) {
-// private def pp(root: AbstractFile, indentLevel: Int) {
-// val spaces = " " * indentLevel
-// out.println(spaces + root.name)
-// if (root.isDirectory)
-// root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1))
-// }
-// // print the contents hierarchically
-// def show() = pp(this, 0)
-// }
-
- /** An interpreter for Scala code.
- *
- * The main public entry points are compile(), interpret(), and bind().
- * The compile() method loads a complete Scala file. The interpret() method
- * executes one line of Scala code at the request of the user. The bind()
- * method binds an object to a variable that can then be used by later
- * interpreted code.
- *
- * The overall approach is based on compiling the requested code and then
- * using a Java classloader and Java reflection to run the code
- * and access its results.
- *
- * In more detail, a single compiler instance is used
- * to accumulate all successfully compiled or interpreted Scala code. To
- * "interpret" a line of code, the compiler generates a fresh object that
- * includes the line of code and which has public member(s) to export
- * all variables defined by that code. To extract the result of an
- * interpreted line to show the user, a second "result object" is created
- * which imports the variables exported by the above object and then
- * exports members called "$eval" and "$print". To accommodate user expressions
- * that read from variables or methods defined in previous statements, "import"
- * statements are used.
- *
- * This interpreter shares the strengths and weaknesses of using the
- * full compiler-to-Java. The main strength is that interpreted code
- * behaves exactly as does compiled code, including running at full speed.
- * The main weakness is that redefining classes and methods is not handled
- * properly, because rebinding at the Java level is technically difficult.
- *
- * @author Moez A. Abdel-Gawad
- * @author Lex Spoon
- */
- @DeveloperApi
- class SparkIMain(
- initialSettings: Settings,
- val out: JPrintWriter,
- propagateExceptions: Boolean = false)
- extends SparkImports with Logging { imain =>
-
- private val conf = new SparkConf()
-
- private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
- /** Local directory to save .class files too */
- private[repl] val outputDir = {
- val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf))
- Utils.createTempDir(root = rootDir, namePrefix = "repl")
- }
- if (SPARK_DEBUG_REPL) {
- echo("Output directory: " + outputDir)
- }
-
- /**
- * Returns the path to the output directory containing all generated
- * class files that will be served by the REPL class server.
- */
- @DeveloperApi
- lazy val getClassOutputDirectory = outputDir
-
- private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles
- /** Jetty server that will serve our classes to worker nodes */
- private var currentSettings: Settings = initialSettings
- private var printResults = true // whether to print result lines
- private var totalSilence = false // whether to print anything
- private var _initializeComplete = false // compiler is initialized
- private var _isInitialized: Future[Boolean] = null // set up initialization future
- private var bindExceptions = true // whether to bind the lastException variable
- private var _executionWrapper = "" // code to be wrapped around all lines
-
- /** We're going to go to some trouble to initialize the compiler asynchronously.
- * It's critical that nothing call into it until it's been initialized or we will
- * run into unrecoverable issues, but the perceived repl startup time goes
- * through the roof if we wait for it. So we initialize it with a future and
- * use a lazy val to ensure that any attempt to use the compiler object waits
- * on the future.
- */
- private var _classLoader: AbstractFileClassLoader = null // active classloader
- private val _compiler: Global = newCompiler(settings, reporter) // our private compiler
-
- private trait ExposeAddUrl extends URLClassLoader { def addNewUrl(url: URL) = this.addURL(url) }
- private var _runtimeClassLoader: URLClassLoader with ExposeAddUrl = null // wrapper exposing addURL
-
- private val nextReqId = {
- var counter = 0
- () => { counter += 1 ; counter }
- }
-
- private def compilerClasspath: Seq[URL] = (
- if (isInitializeComplete) global.classPath.asURLs
- else new PathResolver(settings).result.asURLs // the compiler's classpath
- )
- // NOTE: Exposed to repl package since accessed indirectly from SparkIMain
- private[repl] def settings = currentSettings
- private def mostRecentLine = prevRequestList match {
- case Nil => ""
- case req :: _ => req.originalLine
- }
- // Run the code body with the given boolean settings flipped to true.
- private def withoutWarnings[T](body: => T): T = beQuietDuring {
- val saved = settings.nowarn.value
- if (!saved)
- settings.nowarn.value = true
-
- try body
- finally if (!saved) settings.nowarn.value = false
- }
-
- /** construct an interpreter that reports to Console */
- def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
- def this() = this(new Settings())
-
- private lazy val repllog: Logger = new Logger {
- val out: JPrintWriter = imain.out
- val isInfo: Boolean = BooleanProp keyExists "scala.repl.info"
- val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug"
- val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace"
- }
- private[repl] lazy val formatting: Formatting = new Formatting {
- val prompt = Properties.shellPromptString
- }
-
- // NOTE: Exposed to repl package since used by SparkExprTyper and SparkILoop
- private[repl] lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this)
-
- /**
- * Determines if errors were reported (typically during compilation).
- *
- * @note This is not for runtime errors
- *
- * @return True if had errors, otherwise false
- */
- @DeveloperApi
- def isReportingErrors = reporter.hasErrors
-
- import formatting._
- import reporter.{ printMessage, withoutTruncating }
-
- // This exists mostly because using the reporter too early leads to deadlock.
- private def echo(msg: String) { Console println msg }
- private def _initSources = List(new BatchSourceFile("<init>", "class $repl_$init { }"))
- private def _initialize() = {
- try {
- // todo. if this crashes, REPL will hang
- new _compiler.Run() compileSources _initSources
- _initializeComplete = true
- true
- }
- catch AbstractOrMissingHandler()
- }
- private def tquoted(s: String) = "\"\"\"" + s + "\"\"\""
-
- // argument is a thunk to execute after init is done
- // NOTE: Exposed to repl package since used by SparkILoop
- private[repl] def initialize(postInitSignal: => Unit) {
- synchronized {
- if (_isInitialized == null) {
- _isInitialized = io.spawn {
- try _initialize()
- finally postInitSignal
- }
- }
- }
- }
-
- /**
- * Initializes the underlying compiler/interpreter in a blocking fashion.
- *
- * @note Must be executed before using SparkIMain!
- */
- @DeveloperApi
- def initializeSynchronous(): Unit = {
- if (!isInitializeComplete) {
- _initialize()
- assert(global != null, global)
- }
- }
- private def isInitializeComplete = _initializeComplete
-
- /** the public, go through the future compiler */
-
- /**
- * The underlying compiler used to generate ASTs and execute code.
- */
- @DeveloperApi
- lazy val global: Global = {
- if (isInitializeComplete) _compiler
- else {
- // If init hasn't been called yet you're on your own.
- if (_isInitialized == null) {
- logWarning("Warning: compiler accessed before init set up. Assuming no postInit code.")
- initialize(())
- }
- // // blocks until it is ; false means catastrophic failure
- if (_isInitialized.get()) _compiler
- else null
- }
- }
- @deprecated("Use `global` for access to the compiler instance.", "2.9.0")
- private lazy val compiler: global.type = global
-
- import global._
- import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember}
- import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass}
-
- private implicit class ReplTypeOps(tp: Type) {
- def orElse(other: => Type): Type = if (tp ne NoType) tp else other
- def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp)
- }
-
- // TODO: If we try to make naming a lazy val, we run into big time
- // scalac unhappiness with what look like cycles. It has not been easy to
- // reduce, but name resolution clearly takes different paths.
- // NOTE: Exposed to repl package since used by SparkExprTyper
- private[repl] object naming extends {
- val global: imain.global.type = imain.global
- } with Naming {
- // make sure we don't overwrite their unwisely named res3 etc.
- def freshUserTermName(): TermName = {
- val name = newTermName(freshUserVarName())
- if (definedNameMap contains name) freshUserTermName()
- else name
- }
- def isUserTermName(name: Name) = isUserVarName("" + name)
- def isInternalTermName(name: Name) = isInternalVarName("" + name)
- }
- import naming._
-
- // NOTE: Exposed to repl package since used by SparkILoop
- private[repl] object deconstruct extends {
- val global: imain.global.type = imain.global
- } with StructuredTypeStrings
-
- // NOTE: Exposed to repl package since used by SparkImports
- private[repl] lazy val memberHandlers = new {
- val intp: imain.type = imain
- } with SparkMemberHandlers
- import memberHandlers._
-
- /**
- * Suppresses overwriting print results during the operation.
- *
- * @param body The block to execute
- * @tparam T The return type of the block
- *
- * @return The result from executing the block
- */
- @DeveloperApi
- def beQuietDuring[T](body: => T): T = {
- val saved = printResults
- printResults = false
- try body
- finally printResults = saved
- }
-
- /**
- * Completely masks all output during the operation (minus JVM standard
- * out and error).
- *
- * @param operation The block to execute
- * @tparam T The return type of the block
- *
- * @return The result from executing the block
- */
- @DeveloperApi
- def beSilentDuring[T](operation: => T): T = {
- val saved = totalSilence
- totalSilence = true
- try operation
- finally totalSilence = saved
- }
-
- // NOTE: Exposed to repl package since used by SparkILoop
- private[repl] def quietRun[T](code: String) = beQuietDuring(interpret(code))
-
- private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = {
- case t: ControlThrowable => throw t
- case t: Throwable =>
- logDebug(label + ": " + unwrap(t))
- logDebug(stackTraceString(unwrap(t)))
- alt
- }
- /** takes AnyRef because it may be binding a Throwable or an Exceptional */
-
- private def withLastExceptionLock[T](body: => T, alt: => T): T = {
- assert(bindExceptions, "withLastExceptionLock called incorrectly.")
- bindExceptions = false
-
- try beQuietDuring(body)
- catch logAndDiscard("withLastExceptionLock", alt)
- finally bindExceptions = true
- }
-
- /**
- * Contains the code (in string form) representing a wrapper around all
- * code executed by this instance.
- *
- * @return The wrapper code as a string
- */
- @DeveloperApi
- def executionWrapper = _executionWrapper
-
- /**
- * Sets the code to use as a wrapper around all code executed by this
- * instance.
- *
- * @param code The wrapper code as a string
- */
- @DeveloperApi
- def setExecutionWrapper(code: String) = _executionWrapper = code
-
- /**
- * Clears the code used as a wrapper around all code executed by
- * this instance.
- */
- @DeveloperApi
- def clearExecutionWrapper() = _executionWrapper = ""
-
- /** interpreter settings */
- private lazy val isettings = new SparkISettings(this)
-
- /**
- * Instantiates a new compiler used by SparkIMain. Overridable to provide
- * own instance of a compiler.
- *
- * @param settings The settings to provide the compiler
- * @param reporter The reporter to use for compiler output
- *
- * @return The compiler as a Global
- */
- @DeveloperApi
- protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = {
- settings.outputDirs setSingleOutput virtualDirectory
- settings.exposeEmptyPackage.value = true
- new Global(settings, reporter) with ReplGlobal {
- override def toString: String = "<global>"
- }
- }
-
- /**
- * Adds any specified jars to the compile and runtime classpaths.
- *
- * @note Currently only supports jars, not directories
- * @param urls The list of items to add to the compile and runtime classpaths
- */
- @DeveloperApi
- def addUrlsToClassPath(urls: URL*): Unit = {
- new Run // Needed to force initialization of "something" to correctly load Scala classes from jars
- urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution
- updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling
- }
-
- private def updateCompilerClassPath(urls: URL*): Unit = {
- require(!global.forMSIL) // Only support JavaPlatform
-
- val platform = global.platform.asInstanceOf[JavaPlatform]
-
- val newClassPath = mergeUrlsIntoClassPath(platform, urls: _*)
-
- // NOTE: Must use reflection until this is exposed/fixed upstream in Scala
- val fieldSetter = platform.getClass.getMethods
- .find(_.getName.endsWith("currentClassPath_$eq")).get
- fieldSetter.invoke(platform, Some(newClassPath))
-
- // Reload all jars specified into our compiler
- global.invalidateClassPathEntries(urls.map(_.getPath): _*)
- }
-
- private def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = {
- // Collect our new jars/directories and add them to the existing set of classpaths
- val allClassPaths = (
- platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++
- urls.map(url => {
- platform.classPath.context.newClassPath(
- if (url.getProtocol == "file") {
- val f = new File(url.getPath)
- if (f.isDirectory)
- io.AbstractFile.getDirectory(f)
- else
- io.AbstractFile.getFile(f)
- } else {
- io.AbstractFile.getURL(url)
- }
- )
- })
- ).distinct
-
- // Combine all of our classpaths (old and new) into one merged classpath
- new MergedClassPath(allClassPaths, platform.classPath.context)
- }
-
- /**
- * Represents the parent classloader used by this instance. Can be
- * overridden to provide alternative classloader.
- *
- * @return The classloader used as the parent loader of this instance
- */
- @DeveloperApi
- protected def parentClassLoader: ClassLoader =
- SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() )
-
- /* A single class loader is used for all commands interpreted by this Interpreter.
- It would also be possible to create a new class loader for each command
- to interpret. The advantages of the current approach are:
-
- - Expressions are only evaluated one time. This is especially
- significant for I/O, e.g. "val x = Console.readLine"
-
- The main disadvantage is:
-
- - Objects, classes, and methods cannot be rebound. Instead, definitions
- shadow the old ones, and old code objects refer to the old
- definitions.
- */
- private def resetClassLoader() = {
- logDebug("Setting new classloader: was " + _classLoader)
- _classLoader = null
- ensureClassLoader()
- }
- private final def ensureClassLoader() {
- if (_classLoader == null)
- _classLoader = makeClassLoader()
- }
-
- // NOTE: Exposed to repl package since used by SparkILoop
- private[repl] def classLoader: AbstractFileClassLoader = {
- ensureClassLoader()
- _classLoader
- }
- private class TranslatingClassLoader(parent: ClassLoader) extends AbstractFileClassLoader(virtualDirectory, parent) {
- /** Overridden here to try translating a simple name to the generated
- * class name if the original attempt fails. This method is used by
- * getResourceAsStream as well as findClass.
- */
- override protected def findAbstractFile(name: String): AbstractFile = {
- super.findAbstractFile(name) match {
- // deadlocks on startup if we try to translate names too early
- case null if isInitializeComplete =>
- generatedName(name) map (x => super.findAbstractFile(x)) orNull
- case file =>
- file
- }
- }
- }
- private def makeClassLoader(): AbstractFileClassLoader =
- new TranslatingClassLoader(parentClassLoader match {
- case null => ScalaClassLoader fromURLs compilerClasspath
- case p =>
- _runtimeClassLoader = new URLClassLoader(compilerClasspath, p) with ExposeAddUrl
- _runtimeClassLoader
- })
-
- private def getInterpreterClassLoader() = classLoader
-
- // Set the current Java "context" class loader to this interpreter's class loader
- // NOTE: Exposed to repl package since used by SparkILoopInit
- private[repl] def setContextClassLoader() = classLoader.setAsContext()
-
- /**
- * Returns the real name of a class based on its repl-defined name.
- *
- * ==Example==
- * Given a simple repl-defined name, returns the real name of
- * the class representing it, e.g. for "Bippy" it may return
- * {{{
- * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy
- * }}}
- *
- * @param simpleName The repl-defined name whose real name to retrieve
- *
- * @return Some real name if the simple name exists, else None
- */
- @DeveloperApi
- def generatedName(simpleName: String): Option[String] = {
- if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING)
- else optFlatName(simpleName)
- }
-
- // NOTE: Exposed to repl package since used by SparkILoop
- private[repl] def flatName(id: String) = optFlatName(id) getOrElse id
- // NOTE: Exposed to repl package since used by SparkILoop
- private[repl] def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id)
-
- /**
- * Retrieves all simple names contained in the current instance.
- *
- * @return A list of sorted names
- */
- @DeveloperApi
- def allDefinedNames = definedNameMap.keys.toList.sorted
-
- private def pathToType(id: String): String = pathToName(newTypeName(id))
- // NOTE: Exposed to repl package since used by SparkILoop
- private[repl] def pathToTerm(id: String): String = pathToName(newTermName(id))
-
- /**
- * Retrieves the full code path to access the specified simple name
- * content.
- *
- * @param name The simple name of the target whose path to determine
- *
- * @return The full path used to access the specified target (name)
- */
- @DeveloperApi
- def pathToName(name: Name): String = {
- if (definedNameMap contains name)
- definedNameMap(name) fullPath name
- else name.toString
- }
-
- /** Most recent tree handled which wasn't wholly synthetic. */
- private def mostRecentlyHandledTree: Option[Tree] = {
- prevRequests.reverse foreach { req =>
- req.handlers.reverse foreach {
- case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member)
- case _ => ()
- }
- }
- None
- }
-
- /** Stubs for work in progress. */
- private def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = {
- for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) {
- logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2))
- }
- }
-
- private def handleTermRedefinition(name: TermName, old: Request, req: Request) = {
- for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) {
- // Printing the types here has a tendency to cause assertion errors, like
- // assertion failed: fatal: has owner value x, but a class owner is required
- // so DBG is by-name now to keep it in the family. (It also traps the assertion error,
- // but we don't want to unnecessarily risk hosing the compiler's internal state.)
- logDebug("Redefining term '%s'\n %s -> %s".format(name, t1, t2))
- }
- }
-
- private def recordRequest(req: Request) {
- if (req == null || referencedNameMap == null)
- return
-
- prevRequests += req
- req.referencedNames foreach (x => referencedNameMap(x) = req)
-
- // warning about serially defining companions. It'd be easy
- // enough to just redefine them together but that may not always
- // be what people want so I'm waiting until I can do it better.
- for {
- name <- req.definedNames filterNot (x => req.definedNames contains x.companionName)
- oldReq <- definedNameMap get name.companionName
- newSym <- req.definedSymbols get name
- oldSym <- oldReq.definedSymbols get name.companionName
- if Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }
- } {
- afterTyper(replwarn(s"warning: previously defined $oldSym is not a companion to $newSym."))
- replwarn("Companions must be defined together; you may wish to use :paste mode for this.")
- }
-
- // Updating the defined name map
- req.definedNames foreach { name =>
- if (definedNameMap contains name) {
- if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req)
- else handleTermRedefinition(name.toTermName, definedNameMap(name), req)
- }
- definedNameMap(name) = req
- }
- }
-
- private def replwarn(msg: => String) {
- if (!settings.nowarnings.value)
- printMessage(msg)
- }
-
- private def isParseable(line: String): Boolean = {
- beSilentDuring {
- try parse(line) match {
- case Some(xs) => xs.nonEmpty // parses as-is
- case None => true // incomplete
- }
- catch { case x: Exception => // crashed the compiler
- replwarn("Exception in isParseable(\"" + line + "\"): " + x)
- false
- }
- }
- }
-
- private def compileSourcesKeepingRun(sources: SourceFile*) = {
- val run = new Run()
- reporter.reset()
- run compileSources sources.toList
- (!reporter.hasErrors, run)
- }
-
- /**
- * Compiles specified source files.
- *
- * @param sources The sequence of source files to compile
- *
- * @return True if successful, otherwise false
- */
- @DeveloperApi
- def compileSources(sources: SourceFile*): Boolean =
- compileSourcesKeepingRun(sources: _*)._1
-
- /**
- * Compiles a string of code.
- *
- * @param code The string of code to compile
- *
- * @return True if successful, otherwise false
- */
- @DeveloperApi
- def compileString(code: String): Boolean =
- compileSources(new BatchSourceFile("