diff --git a/.gitignore b/.gitignore index e1231c7374dc..125c5adb9689 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ R-unit-tests.log R/unit-tests.out R/cran-check.out R/pkg/vignettes/sparkr-vignettes.html +R/pkg/tests/fulltests/Rplots.pdf build/*.jar build/apache-maven* build/scala* diff --git a/LICENSE b/LICENSE index 66a2e8f13295..39fe0dc46238 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.6 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index b7fdae58de45..232f5cf31f31 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -429,6 +429,7 @@ export("structField", "structField.character", "print.structField", "structType", + "structType.character", "structType.jobj", "structType.structField", "print.structType") @@ -465,5 +466,6 @@ S3method(print, summary.GBTRegressionModel) S3method(print, summary.GBTClassificationModel) S3method(structField, character) S3method(structField, jobj) +S3method(structType, character) S3method(structType, jobj) S3method(structType, structField) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 3b9d42d6e715..e7a166c3014c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1391,6 +1391,10 @@ setMethod("summarize", }) dapplyInternal <- function(x, func, schema) { + if (is.character(schema)) { + schema <- structType(schema) + } + packageNamesArr <- serialize(.sparkREnv[[".packages"]], connection = NULL) @@ -1408,6 +1412,8 @@ dapplyInternal <- function(x, func, schema) { dataFrame(sdf) } +setClassUnion("characterOrstructType", c("character", "structType")) + #' dapply #' #' Apply a function to each partition of a SparkDataFrame. @@ -1418,10 +1424,11 @@ dapplyInternal <- function(x, func, schema) { #' to each partition will be passed. #' The output of func should be a R data.frame. #' @param schema The schema of the resulting SparkDataFrame after the function is applied. -#' It must match the output of func. +#' It must match the output of func. Since Spark 2.3, the DDL-formatted string +#' is also supported for the schema. #' @family SparkDataFrame functions #' @rdname dapply -#' @aliases dapply,SparkDataFrame,function,structType-method +#' @aliases dapply,SparkDataFrame,function,characterOrstructType-method #' @name dapply #' @seealso \link{dapplyCollect} #' @export @@ -1444,6 +1451,17 @@ dapplyInternal <- function(x, func, schema) { #' y <- cbind(y, y[1] + 1L) #' }, #' schema) +#' +#' # The schema also can be specified in a DDL-formatted string. 
+#' schema <- "a INT, d DOUBLE, c STRING, d INT" +#' df1 <- dapply( +#' df, +#' function(x) { +#' y <- x[x[1] > 1, ] +#' y <- cbind(y, y[1] + 1L) +#' }, +#' schema) +#' #' collect(df1) #' # the result #' # a b c d @@ -1452,7 +1470,7 @@ dapplyInternal <- function(x, func, schema) { #' } #' @note dapply since 2.0.0 setMethod("dapply", - signature(x = "SparkDataFrame", func = "function", schema = "structType"), + signature(x = "SparkDataFrame", func = "function", schema = "characterOrstructType"), function(x, func, schema) { dapplyInternal(x, func, schema) }) @@ -1522,6 +1540,7 @@ setMethod("dapplyCollect", #' @param schema the schema of the resulting SparkDataFrame after the function is applied. #' The schema must match to output of \code{func}. It has to be defined for each #' output column with preferred output column name and corresponding data type. +#' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @return A SparkDataFrame. #' @family SparkDataFrame functions #' @aliases gapply,SparkDataFrame-method @@ -1541,7 +1560,7 @@ setMethod("dapplyCollect", #' #' Here our output contains three columns, the key which is a combination of two #' columns with data types integer and string and the mean which is a double. -#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' schema <- structType(structField("a", "integer"), structField("c", "string"), #' structField("avg", "double")) #' result <- gapply( #' df, @@ -1550,6 +1569,15 @@ setMethod("dapplyCollect", #' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) #' }, schema) #' +#' The schema also can be specified in a DDL-formatted string. +#' schema <- "a INT, c STRING, avg DOUBLE" +#' result <- gapply( +#' df, +#' c("a", "c"), +#' function(key, x) { +#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) +#' }, schema) +#' #' We can also group the data and afterwards call gapply on GroupedData. #' For Example: #' gdf <- group_by(df, "a", "c") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index a1f5c4f8cc18..86507f13f038 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -38,10 +38,10 @@ NULL #' #' Date time functions defined for \code{Column}. #' -#' @param x Column to compute on. +#' @param x Column to compute on. In \code{window}, it must be a time Column of \code{TimestampType}. #' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse -#' x Column to DateType or TimestampType. For \code{trunc}, it is the string used -#' for specifying the truncation method. For example, "year", "yyyy", "yy" for +#' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string +#' to use to specify the truncation method. For example, "year", "yyyy", "yy" for #' truncate by year, or "month", "mon", "mm" for truncate by month. #' @param ... additional argument(s). #' @name column_datetime_functions @@ -122,7 +122,7 @@ NULL #' format to. See 'Details'. #' } #' @param y Column to compute on. -#' @param ... additional columns. +#' @param ... additional Columns. 
#' @name column_string_functions #' @rdname column_string_functions #' @family string functions @@ -167,8 +167,7 @@ NULL #' tmp <- mutate(df, v1 = crc32(df$model), v2 = hash(df$model), #' v3 = hash(df$model, df$mpg), v4 = md5(df$model), #' v5 = sha1(df$model), v6 = sha2(df$model, 256)) -#' head(tmp) -#' } +#' head(tmp)} NULL #' Collection functions for Column operations @@ -190,7 +189,6 @@ NULL #' \dontrun{ #' # Dataframe used throughout this doc #' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) -#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) @@ -200,6 +198,34 @@ NULL #' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))} NULL +#' Window functions for Column operations +#' +#' Window functions defined for \code{Column}. +#' +#' @param x In \code{lag} and \code{lead}, it is the column as a character string or a Column +#' to compute on. In \code{ntile}, it is the number of ntile groups. +#' @param offset In \code{lag}, the number of rows back from the current row from which to obtain +#' a value. In \code{lead}, the number of rows after the current row from which to +#' obtain a value. If not specified, the default is 1. +#' @param defaultValue (optional) default to use when the offset row does not exist. +#' @param ... additional argument(s). +#' @name column_window_functions +#' @rdname column_window_functions +#' @family window functions +#' @examples +#' \dontrun{ +#' # Dataframe used throughout this doc +#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars)) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' tmp <- mutate(df, dist = over(cume_dist(), ws), dense_rank = over(dense_rank(), ws), +#' lag = over(lag(df$mpg), ws), lead = over(lead(df$mpg, 1), ws), +#' percent_rank = over(percent_rank(), ws), +#' rank = over(rank(), ws), row_number = over(row_number(), ws)) +#' # Get ntile group id (1-4) for hp +#' tmp <- mutate(tmp, ntile = over(ntile(4), ws)) +#' head(tmp)} +NULL + #' @details #' \code{lit}: A new Column is created to represent the literal value. #' If the parameter is a Column, it is returned unchanged. @@ -310,7 +336,8 @@ setMethod("asin", }) #' @details -#' \code{atan}: Computes the tangent inverse of the given value. +#' \code{atan}: Computes the tangent inverse of the given value; the returned angle is in the range +#' -pi/2 through pi/2. #' #' @rdname column_math_functions #' @export @@ -366,7 +393,7 @@ setMethod("base64", }) #' @details -#' \code{bin}: An expression that returns the string representation of the binary value +#' \code{bin}: Returns the string representation of the binary value #' of the given long column. For example, bin("12") returns "1100". #' #' @rdname column_math_functions @@ -573,7 +600,7 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr }) #' @details -#' \code{cos}: Computes the cosine of the given value. +#' \code{cos}: Computes the cosine of the given value. Units in radians. #' #' @rdname column_math_functions #' @aliases cos cos,Column-method @@ -694,7 +721,7 @@ setMethod("dayofyear", #' \code{decode}: Computes the first argument into a string from a binary using the provided #' character set. 
#' -#' @param charset Character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", +#' @param charset character set to use (one of "US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", #' "UTF-16LE", "UTF-16"). #' #' @rdname column_string_functions @@ -827,7 +854,7 @@ setMethod("hex", }) #' @details -#' \code{hour}: Extracts the hours as an integer from a given date/timestamp/string. +#' \code{hour}: Extracts the hour as an integer from a given date/timestamp/string. #' #' @rdname column_datetime_functions #' @aliases hour hour,Column-method @@ -1149,7 +1176,7 @@ setMethod("min", }) #' @details -#' \code{minute}: Extracts the minutes as an integer from a given date/timestamp/string. +#' \code{minute}: Extracts the minute as an integer from a given date/timestamp/string. #' #' @rdname column_datetime_functions #' @aliases minute minute,Column-method @@ -1326,7 +1353,7 @@ setMethod("sd", }) #' @details -#' \code{second}: Extracts the seconds as an integer from a given date/timestamp/string. +#' \code{second}: Extracts the second as an integer from a given date/timestamp/string. #' #' @rdname column_datetime_functions #' @aliases second second,Column-method @@ -1381,7 +1408,7 @@ setMethod("sign", signature(x = "Column"), }) #' @details -#' \code{sin}: Computes the sine of the given value. +#' \code{sin}: Computes the sine of the given value. Units in radians. #' #' @rdname column_math_functions #' @aliases sin sin,Column-method @@ -1436,20 +1463,18 @@ setMethod("soundex", column(jc) }) -#' Return the partition ID as a column -#' -#' Return the partition ID as a SparkDataFrame column. +#' @details +#' \code{spark_partition_id}: Returns the partition ID as a SparkDataFrame column. #' Note that this is nondeterministic because it depends on data partitioning and #' task scheduling. +#' This is equivalent to the \code{SPARK_PARTITION_ID} function in SQL. #' -#' This is equivalent to the SPARK_PARTITION_ID function in SQL. -#' -#' @rdname spark_partition_id -#' @name spark_partition_id -#' @aliases spark_partition_id,missing-method +#' @rdname column_nonaggregate_functions +#' @aliases spark_partition_id spark_partition_id,missing-method #' @export #' @examples -#' \dontrun{select(df, spark_partition_id())} +#' +#' \dontrun{head(select(df, spark_partition_id()))} #' @note spark_partition_id since 2.0.0 setMethod("spark_partition_id", signature("missing"), @@ -1573,7 +1598,7 @@ setMethod("sumDistinct", }) #' @details -#' \code{tan}: Computes the tangent of the given value. +#' \code{tan}: Computes the tangent of the given value. Units in radians. #' #' @rdname column_math_functions #' @aliases tan tan,Column-method @@ -1872,7 +1897,7 @@ setMethod("year", #' @details #' \code{atan2}: Returns the angle theta from the conversion of rectangular coordinates -#' (x, y) to polar coordinates (r, theta). +#' (x, y) to polar coordinates (r, theta). Units in radians. #' #' @rdname column_math_functions #' @aliases atan2 atan2,Column-method @@ -2000,7 +2025,7 @@ setMethod("pmod", signature(y = "Column"), column(jc) }) -#' @param rsd maximum estimation error allowed (default = 0.05) +#' @param rsd maximum estimation error allowed (default = 0.05). #' #' @rdname column_aggregate_functions #' @aliases approxCountDistinct,Column-method @@ -2149,8 +2174,9 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' #' @rdname column_collection_functions #' @param schema a structType object to use as the schema to use when parsing the JSON string. 
+#' Since Spark 2.3, the DDL-formatted string is also supported for the schema.
#' @param as.json.array indicating if input string is JSON array of objects or a single object.
-#' @aliases from_json from_json,Column,structType-method
+#' @aliases from_json from_json,Column,characterOrstructType-method
#' @export
#' @examples
#'
@@ -2163,10 +2189,15 @@ setMethod("date_format", signature(y = "Column", x = "character"),
#' df2 <- sql("SELECT named_struct('name', 'Bob') as people")
#' df2 <- mutate(df2, people_json = to_json(df2$people))
#' schema <- structType(structField("name", "string"))
-#' head(select(df2, from_json(df2$people_json, schema)))}
+#' head(select(df2, from_json(df2$people_json, schema)))
+#' head(select(df2, from_json(df2$people_json, "name STRING")))}
#' @note from_json since 2.2.0
-setMethod("from_json", signature(x = "Column", schema = "structType"),
+setMethod("from_json", signature(x = "Column", schema = "characterOrstructType"),
function(x, schema, as.json.array = FALSE, ...) {
+ if (is.character(schema)) {
+ schema <- structType(schema)
+ }
+
if (as.json.array) {
jschema <- callJStatic("org.apache.spark.sql.types.DataTypes",
"createArrayType",
@@ -2192,8 +2223,8 @@ setMethod("from_json", signature(x = "Column", schema = "structType"),
#' @examples
#'
#' \dontrun{
-#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, 'PST'),
-#' to_utc = to_utc_timestamp(df$time, 'PST'))
+#' tmp <- mutate(df, from_utc = from_utc_timestamp(df$time, "PST"),
+#' to_utc = to_utc_timestamp(df$time, "PST"))
#' head(tmp)}
#' @note from_utc_timestamp since 1.5.0
setMethod("from_utc_timestamp", signature(y = "Column", x = "character"),
@@ -2227,7 +2258,7 @@ setMethod("instr", signature(y = "Column", x = "character"),
#' @details
#' \code{next_day}: Given a date column, returns the first date which is later than the value of
#' the date column that is on the specified day of the week. For example,
-#' \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first Sunday
+#' \code{next_day("2015-07-27", "Sunday")} returns 2015-08-02 because that is the first Sunday
#' after 2015-07-27. Day of the week parameter is case insensitive, and accepts first three or
#' two characters: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun".
#'
@@ -2267,7 +2298,7 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"),
#' tmp <- mutate(df, t1 = add_months(df$time, 1),
#' t2 = date_add(df$time, 2),
#' t3 = date_sub(df$time, 3),
-#' t4 = next_day(df$time, 'Sun'))
+#' t4 = next_day(df$time, "Sun"))
#' head(tmp)}
#' @note add_months since 1.5.0
setMethod("add_months", signature(y = "Column", x = "numeric"),
@@ -2376,8 +2407,8 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"),
})
#' @details
-#' \code{shiftRight}: (Unigned) shifts the given value numBits right. If the given value is a long value,
-#' it will return a long value else it will return an integer value.
+#' \code{shiftRightUnsigned}: (Unsigned) shifts the given value numBits right. If the given value is
+#' a long value, it will return a long value else it will return an integer value.
#'
#' @rdname column_math_functions
#' @aliases shiftRightUnsigned shiftRightUnsigned,Column,numeric-method
@@ -2485,14 +2516,13 @@ setMethod("from_unixtime", signature(x = "Column"),
column(jc)
})
-#' window
-#'
-#' Bucketize rows into one or more time windows given a timestamp specifying column. Window
-#' starts are inclusive but the window ends are exclusive, e.g.
12:05 will be in the window +#' @details +#' \code{window}: Bucketizes rows into one or more time windows given a timestamp specifying column. +#' Window starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window #' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in -#' the order of months are not supported. +#' the order of months are not supported. It returns an output column of struct called 'window' +#' by default with the nested columns 'start' and 'end' #' -#' @param x a time Column. Must be of TimestampType. #' @param windowDuration a string specifying the width of the window, e.g. '1 second', #' '1 day 12 hours', '2 minutes'. Valid interval strings are 'week', #' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. Note that @@ -2508,27 +2538,22 @@ setMethod("from_unixtime", signature(x = "Column"), #' window intervals. For example, in order to have hourly tumbling windows #' that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide #' \code{startTime} as \code{"15 minutes"}. -#' @param ... further arguments to be passed to or from other methods. -#' @return An output column of struct called 'window' by default with the nested columns 'start' -#' and 'end'. -#' @family date time functions -#' @rdname window -#' @name window -#' @aliases window,Column-method +#' @rdname column_datetime_functions +#' @aliases window window,Column-method #' @export #' @examples -#'\dontrun{ -#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10, -#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ... -#' window(df$time, "1 minute", "15 seconds", "10 seconds") #' -#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15, -#' # 09:01:15-09:02:15... -#' window(df$time, "1 minute", startTime = "15 seconds") +#' \dontrun{ +#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10, +#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ... +#' window(df$time, "1 minute", "15 seconds", "10 seconds") #' -#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... -#' window(df$time, "30 seconds", "10 seconds") -#'} +#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15, +#' # 09:01:15-09:02:15... +#' window(df$time, "1 minute", startTime = "15 seconds") +#' +#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... +#' window(df$time, "30 seconds", "10 seconds")} #' @note window since 2.0.0 setMethod("window", signature(x = "Column"), function(x, windowDuration, slideDuration = NULL, startTime = NULL) { @@ -2844,27 +2869,16 @@ setMethod("ifelse", ###################### Window functions###################### -#' cume_dist -#' -#' Window function: returns the cumulative distribution of values within a window partition, -#' i.e. the fraction of rows that are below the current row. -#' -#' N = total number of rows in the partition -#' cume_dist(x) = number of values before (and including) x / N -#' +#' @details +#' \code{cume_dist}: Returns the cumulative distribution of values within a window partition, +#' i.e. the fraction of rows that are below the current row: +#' (number of values before and including x) / (total number of rows in the partition). #' This is equivalent to the \code{CUME_DIST} function in SQL. +#' The method should be used with no argument. 
#' -#' @rdname cume_dist -#' @name cume_dist -#' @family window functions -#' @aliases cume_dist,missing-method +#' @rdname column_window_functions +#' @aliases cume_dist cume_dist,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(cume_dist(), ws), df$hp, df$am) -#' } #' @note cume_dist since 1.6.0 setMethod("cume_dist", signature("missing"), @@ -2873,28 +2887,19 @@ setMethod("cume_dist", column(jc) }) -#' dense_rank -#' -#' Window function: returns the rank of rows within a window partition, without any gaps. +#' @details +#' \code{dense_rank}: Returns the rank of rows within a window partition, without any gaps. #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. Rank would give me sequential numbers, making #' the person that came in third place (after the ties) would register as coming in fifth. -#' #' This is equivalent to the \code{DENSE_RANK} function in SQL. +#' The method should be used with no argument. #' -#' @rdname dense_rank -#' @name dense_rank -#' @family window functions -#' @aliases dense_rank,missing-method +#' @rdname column_window_functions +#' @aliases dense_rank dense_rank,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(dense_rank(), ws), df$hp, df$am) -#' } #' @note dense_rank since 1.6.0 setMethod("dense_rank", signature("missing"), @@ -2903,34 +2908,15 @@ setMethod("dense_rank", column(jc) }) -#' lag -#' -#' Window function: returns the value that is \code{offset} rows before the current row, and +#' @details +#' \code{lag}: Returns the value that is \code{offset} rows before the current row, and #' \code{defaultValue} if there is less than \code{offset} rows before the current row. For example, #' an \code{offset} of one will return the previous row at any given point in the window partition. -#' #' This is equivalent to the \code{LAG} function in SQL. #' -#' @param x the column as a character string or a Column to compute on. -#' @param offset the number of rows back from the current row from which to obtain a value. -#' If not specified, the default is 1. -#' @param defaultValue (optional) default to use when the offset row does not exist. -#' @param ... further arguments to be passed to or from other methods. 
-#' @rdname lag -#' @name lag -#' @aliases lag,characterOrColumn-method -#' @family window functions +#' @rdname column_window_functions +#' @aliases lag lag,characterOrColumn-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Lag mpg values by 1 row on the partition-and-ordered table -#' out <- select(df, over(lag(df$mpg), ws), df$mpg, df$hp, df$am) -#' } #' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), @@ -2946,35 +2932,16 @@ setMethod("lag", column(jc) }) -#' lead -#' -#' Window function: returns the value that is \code{offset} rows after the current row, and +#' @details +#' \code{lead}: Returns the value that is \code{offset} rows after the current row, and #' \code{defaultValue} if there is less than \code{offset} rows after the current row. #' For example, an \code{offset} of one will return the next row at any given point #' in the window partition. -#' #' This is equivalent to the \code{LEAD} function in SQL. #' -#' @param x the column as a character string or a Column to compute on. -#' @param offset the number of rows after the current row from which to obtain a value. -#' If not specified, the default is 1. -#' @param defaultValue (optional) default to use when the offset row does not exist. -#' -#' @rdname lead -#' @name lead -#' @family window functions -#' @aliases lead,characterOrColumn,numeric-method +#' @rdname column_window_functions +#' @aliases lead lead,characterOrColumn,numeric-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Lead mpg values by 1 row on the partition-and-ordered table -#' out <- select(df, over(lead(df$mpg), ws), df$mpg, df$hp, df$am) -#' } #' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), @@ -2990,31 +2957,15 @@ setMethod("lead", column(jc) }) -#' ntile -#' -#' Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window +#' @details +#' \code{ntile}: Returns the ntile group id (from 1 to n inclusive) in an ordered window #' partition. For example, if n is 4, the first quarter of the rows will get value 1, the second #' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. -#' #' This is equivalent to the \code{NTILE} function in SQL. #' -#' @param x Number of ntile groups -#' -#' @rdname ntile -#' @name ntile -#' @aliases ntile,numeric-method -#' @family window functions +#' @rdname column_window_functions +#' @aliases ntile ntile,numeric-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' -#' # Partition by am (transmission) and order by hp (horsepower) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' -#' # Get ntile group id (1-4) for hp -#' out <- select(df, over(ntile(4), ws), df$hp, df$am) -#' } #' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), @@ -3023,27 +2974,15 @@ setMethod("ntile", column(jc) }) -#' percent_rank -#' -#' Window function: returns the relative rank (i.e. percentile) of rows within a window partition. -#' -#' This is computed by: -#' -#' (rank of row in its partition - 1) / (number of rows in the partition - 1) -#' -#' This is equivalent to the PERCENT_RANK function in SQL. 
+#' @details +#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window partition. +#' This is computed by: (rank of row in its partition - 1) / (number of rows in the partition - 1). +#' This is equivalent to the \code{PERCENT_RANK} function in SQL. +#' The method should be used with no argument. #' -#' @rdname percent_rank -#' @name percent_rank -#' @family window functions -#' @aliases percent_rank,missing-method +#' @rdname column_window_functions +#' @aliases percent_rank percent_rank,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(percent_rank(), ws), df$hp, df$am) -#' } #' @note percent_rank since 1.6.0 setMethod("percent_rank", signature("missing"), @@ -3052,29 +2991,19 @@ setMethod("percent_rank", column(jc) }) -#' rank -#' -#' Window function: returns the rank of rows within a window partition. -#' +#' @details +#' \code{rank}: Returns the rank of rows within a window partition. #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. Rank would give me sequential numbers, making #' the person that came in third place (after the ties) would register as coming in fifth. +#' This is equivalent to the \code{RANK} function in SQL. +#' The method should be used with no argument. #' -#' This is equivalent to the RANK function in SQL. -#' -#' @rdname rank -#' @name rank -#' @family window functions -#' @aliases rank,missing-method +#' @rdname column_window_functions +#' @aliases rank rank,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(rank(), ws), df$hp, df$am) -#' } #' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), @@ -3083,11 +3012,7 @@ setMethod("rank", column(jc) }) -# Expose rank() in the R base package -#' @param x a numeric, complex, character or logical vector. -#' @param ... additional argument(s) passed to the method. -#' @name rank -#' @rdname rank +#' @rdname column_window_functions #' @aliases rank,ANY-method #' @export setMethod("rank", @@ -3096,23 +3021,14 @@ setMethod("rank", base::rank(x, ...) }) -#' row_number -#' -#' Window function: returns a sequential number starting at 1 within a window partition. -#' -#' This is equivalent to the ROW_NUMBER function in SQL. +#' @details +#' \code{row_number}: Returns a sequential number starting at 1 within a window partition. +#' This is equivalent to the \code{ROW_NUMBER} function in SQL. +#' The method should be used with no argument. #' -#' @rdname row_number -#' @name row_number -#' @aliases row_number,missing-method -#' @family window functions +#' @rdname column_window_functions +#' @aliases row_number row_number,missing-method #' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(mtcars) -#' ws <- orderBy(windowPartitionBy("am"), "hp") -#' out <- select(df, over(row_number(), ws), df$hp, df$am) -#' } #' @note row_number since 1.6.0 setMethod("row_number", signature("missing"), @@ -3127,7 +3043,7 @@ setMethod("row_number", #' \code{array_contains}: Returns null if the array is null, true if the array contains #' the value, and false otherwise. 
#' -#' @param value A value to be checked if contained in the column +#' @param value a value to be checked if contained in the column #' @rdname column_collection_functions #' @aliases array_contains array_contains,Column-method #' @export @@ -3172,7 +3088,7 @@ setMethod("size", #' to the natural ordering of the array elements. #' #' @rdname column_collection_functions -#' @param asc A logical flag indicating the sorting order. +#' @param asc a logical flag indicating the sorting order. #' TRUE, sorting is in ascending order. #' FALSE, sorting is in descending order. #' @aliases sort_array sort_array,Column-method @@ -3299,7 +3215,7 @@ setMethod("split_string", #' \code{repeat_string}: Repeats string n times. #' Equivalent to \code{repeat} SQL function. #' -#' @param n Number of repetitions +#' @param n number of repetitions. #' @rdname column_string_functions #' @aliases repeat_string repeat_string,Column-method #' @export @@ -3428,7 +3344,7 @@ setMethod("grouping_bit", #' \code{grouping_id}: Returns the level of grouping. #' Equals to \code{ #' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) -#' } +#' }. #' #' @rdname column_aggregate_functions #' @aliases grouping_id grouping_id,Column-method diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index b901b74e4728..92098741f72f 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1013,9 +1013,9 @@ setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) #' @name NULL setGeneric("hash", function(x, ...) { standardGeneric("hash") }) -#' @param x empty. Should be used with no argument. -#' @rdname cume_dist +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) #' @rdname column_datetime_diff_functions @@ -1053,9 +1053,9 @@ setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @name NULL setGeneric("decode", function(x, charset) { standardGeneric("decode") }) -#' @param x empty. Should be used with no argument. -#' @rdname dense_rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) #' @rdname column_string_functions @@ -1159,8 +1159,9 @@ setGeneric("isnan", function(x) { standardGeneric("isnan") }) #' @name NULL setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) -#' @rdname lag +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("lag", function(x, ...) { standardGeneric("lag") }) #' @rdname last @@ -1172,8 +1173,9 @@ setGeneric("last", function(x, ...) { standardGeneric("last") }) #' @name NULL setGeneric("last_day", function(x) { standardGeneric("last_day") }) -#' @rdname lead +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("lead", function(x, offset, defaultValue = NULL) { standardGeneric("lead") }) #' @rdname column_nonaggregate_functions @@ -1260,8 +1262,9 @@ setGeneric("not", function(x) { standardGeneric("not") }) #' @name NULL setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) -#' @rdname ntile +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @rdname column_aggregate_functions @@ -1269,9 +1272,9 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @name NULL setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) -#' @param x empty. Should be used with no argument. 
-#' @rdname percent_rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) #' @rdname column_math_functions @@ -1304,8 +1307,9 @@ setGeneric("rand", function(seed) { standardGeneric("rand") }) #' @name NULL setGeneric("randn", function(seed) { standardGeneric("randn") }) -#' @rdname rank +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("rank", function(x, ...) { standardGeneric("rank") }) #' @rdname column_string_functions @@ -1334,9 +1338,9 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @name NULL setGeneric("rint", function(x) { standardGeneric("rint") }) -#' @param x empty. Should be used with no argument. -#' @rdname row_number +#' @rdname column_window_functions #' @export +#' @name NULL setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) #' @rdname column_string_functions @@ -1414,9 +1418,9 @@ setGeneric("split_string", function(x, pattern) { standardGeneric("split_string" #' @name NULL setGeneric("soundex", function(x) { standardGeneric("soundex") }) -#' @param x empty. Should be used with no argument. -#' @rdname spark_partition_id +#' @rdname column_nonaggregate_functions #' @export +#' @name NULL setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) #' @rdname column_aggregate_functions @@ -1534,8 +1538,9 @@ setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @name NULL setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) -#' @rdname window +#' @rdname column_datetime_functions #' @export +#' @name NULL setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @rdname column_datetime_functions diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 17f5283abead..0a7be0e99397 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -233,6 +233,9 @@ setMethod("gapplyCollect", }) gapplyInternal <- function(x, func, schema) { + if (is.character(schema)) { + schema <- structType(schema) + } packageNamesArr <- serialize(.sparkREnv[[".packages"]], connection = NULL) broadcastArr <- lapply(ls(.broadcastNames), diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 2f1220a75278..75b1a74ee8c7 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. 
@@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, - maxMemoryInMB = 256, cacheNodeIds = FALSE) { + maxMemoryInMB = 256, cacheNodeIds = FALSE, + handleInvalid = c("error", "keep", "skip")) { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo new("RandomForestRegressionModel", jobj = jobj) }, classification = { + handleInvalid <- match.arg(handleInvalid) if (is.null(impurity)) impurity <- "gini" impurity <- match.arg(impurity, c("gini", "entropy")) jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", @@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo as.numeric(minInfoGain), as.integer(checkpointInterval), as.character(featureSubsetStrategy), seed, as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + as.integer(maxMemoryInMB), as.logical(cacheNodeIds), + handleInvalid) new("RandomForestClassificationModel", jobj = jobj) } ) diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index cb5bdb90175b..d1ed6833d5d0 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -23,18 +23,24 @@ #' Create a structType object that contains the metadata for a SparkDataFrame. Intended for #' use with createDataFrame and toDF. #' -#' @param x a structField object (created with the field() function) +#' @param x a structField object (created with the \code{structField} method). Since Spark 2.3, +#' this can be a DDL-formatted string, which is a comma separated list of field +#' definitions, e.g., "a INT, b STRING". #' @param ... additional structField objects #' @return a structType object #' @rdname structType #' @export #' @examples #'\dontrun{ -#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' schema <- structType(structField("a", "integer"), structField("c", "string"), #' structField("avg", "double")) #' df1 <- gapply(df, list("a", "c"), #' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, #' schema) +#' schema <- structType("a INT, c STRING, avg DOUBLE") +#' df1 <- gapply(df, list("a", "c"), +#' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, +#' schema) #' } #' @note structType since 1.4.0 structType <- function(x, ...) { @@ -68,6 +74,23 @@ structType.structField <- function(x, ...) { structType(stObj) } +#' @rdname structType +#' @method structType character +#' @export +structType.character <- function(x, ...) { + if (!is.character(x)) { + stop("schema must be a DDL-formatted string.") + } + if (length(list(...)) > 0) { + stop("multiple DDL-formatted strings are not supported") + } + + stObj <- handledCallJStatic("org.apache.spark.sql.types.StructType", + "fromDDL", + x) + structType(stObj) +} + #' Print a Spark StructType. #' #' This function prints the contents of a StructType returned from the @@ -102,7 +125,7 @@ print.structType <- function(x, ...) 
{
#' field1 <- structField("a", "integer")
#' field2 <- structField("c", "string")
#' field3 <- structField("avg", "double")
-#' schema <- structType(field1, field2, field3)
+#' schema <- structType(field1, field2, field3)
#' df1 <- gapply(df, list("a", "c"),
#' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) },
#' schema)
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index f2d2620e5447..81507ea7186a 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -113,7 +113,7 @@ sparkR.stop <- function() {
#' list(spark.executor.memory="4g"),
#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"),
#' c("one.jar", "two.jar", "three.jar"),
-#' c("com.databricks:spark-avro_2.10:2.0.1"))
+#' c("com.databricks:spark-avro_2.11:2.0.1"))
#'}
#' @note sparkR.init since 1.4.0
sparkR.init <- function(
@@ -357,7 +357,7 @@ sparkRHive.init <- function(jsc = NULL) {
#' sparkR.session("yarn-client", "SparkR", "/home/spark",
#' list(spark.executor.memory="4g"),
#' c("one.jar", "two.jar", "three.jar"),
-#' c("com.databricks:spark-avro_2.10:2.0.1"))
+#' c("com.databricks:spark-avro_2.11:2.0.1"))
#' sparkR.session(spark.master = "yarn-client", spark.executor.memory = "4g")
#'}
#' @note sparkR.session since 2.0.0
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
index 3a318b71ea06..2e31dc5f728c 100644
--- a/R/pkg/inst/worker/daemon.R
+++ b/R/pkg/inst/worker/daemon.R
@@ -30,8 +30,50 @@ port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
+# Waits indefinitely for a socket connection by default.
+selectTimeout <- NULL
+
while (TRUE) {
- ready <- socketSelect(list(inputCon))
+ ready <- socketSelect(list(inputCon), timeout = selectTimeout)
+
+ # Note that the children should be terminated in the parent. If each child terminates
+ # itself, it appears that the resource is not released properly, which causes an unexpected
+ # termination of this daemon due to, for example, running out of file descriptors
+ # (see SPARK-21093). Therefore, the current implementation tries to retrieve children
+ # that are exited (but not terminated) and then sends a kill signal to terminate them properly
+ # in the parent.
+ #
+ # There are two paths that it attempts to send a signal to terminate the children in the parent.
+ #
+ # 1. Every second if any socket connection is not available and if there are child workers
+ # running.
+ # 2. Right after a socket connection is available.
+ #
+ # In other words, the parent attempts to send the signal to the children every second if
+ # any worker is running or right before launching other worker children from the following
+ # new socket connection.
+
+ # The process IDs of exited children are returned below.
+ children <- parallel:::selectChildren(timeout = 0)
+
+ if (is.integer(children)) {
+ lapply(children, function(child) {
+ # This should be the PIDs of exited children. Otherwise, this returns raw bytes if any data
+ # was sent from this child. In this case, we discard it.
+ pid <- parallel:::readChild(child)
+ if (is.integer(pid)) {
+ # This checks if the data from this child is the same pid of this selected child.
+ if (child == pid) {
+ # If so, we terminate this child.
+ tools::pskill(child, tools::SIGUSR1)
+ }
+ }
+ })
+ } else if (is.null(children)) {
+ # If it is NULL, there are no children. Waits indefinitely for a socket connection.
+ selectTimeout <- NULL + } + if (ready) { port <- SparkR:::readInt(inputCon) # There is a small chance that it could be interrupted by signal, retry one time @@ -44,12 +86,15 @@ while (TRUE) { } p <- parallel:::mcfork() if (inherits(p, "masterProcess")) { + # Reach here because this is a child process. close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) try(source(script)) - # Set SIGUSR1 so that child can exit - tools::pskill(Sys.getpid(), tools::SIGUSR1) + # Note that this mcexit does not fully terminate this child. parallel:::mcexit(0L) + } else { + # Forking succeeded and we need to check if they finished their jobs every second. + selectTimeout <- 1 } } } diff --git a/R/pkg/tests/fulltests/test_client.R b/R/pkg/tests/fulltests/test_client.R index 0cf25fe1dbf3..de624b572cc2 100644 --- a/R/pkg/tests/fulltests/test_client.R +++ b/R/pkg/tests/fulltests/test_client.R @@ -37,7 +37,7 @@ test_that("multiple packages don't produce a warning", { test_that("sparkJars sparkPackages as character vectors", { args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", - c("com.databricks:spark-avro_2.10:2.0.1")) + c("com.databricks:spark-avro_2.11:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") - expect_match(args, "--packages com.databricks:spark-avro_2.10:2.0.1") + expect_match(args, "--packages com.databricks:spark-avro_2.11:2.0.1") }) diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R index 9b3fc8d270b2..66a0693a59a5 100644 --- a/R/pkg/tests/fulltests/test_mllib_tree.R +++ b/R/pkg/tests/fulltests/test_mllib_tree.R @@ -212,6 +212,23 @@ test_that("spark.randomForest", { expect_equal(length(grep("1.0", predictions)), 50) expect_equal(length(grep("2.0", predictions)), 50) + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.randomForest(traindf, clicked ~ ., type = "classification", + maxDepth = 10, maxBins = 10, numTrees = 10) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.randomForest(traindf, clicked ~ ., type = "classification", + maxDepth = 10, maxBins = 10, numTrees = 10, + handleInvalid = "skip") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") + # spark.randomForest classification can work on libsvm data if (windows_with_hadoop()) { data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a2bcb5aefe16..77052d4a2834 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -146,6 +146,13 @@ test_that("structType and structField", { expect_is(testSchema, "structType") expect_is(testSchema$fields()[[2]], "structField") expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") + + testSchema <- structType("a STRING, b INT") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") + + expect_error(structType("A stri"), "DataType stri is not supported.") }) test_that("structField type strings", { 
@@ -1480,13 +1487,15 @@ test_that("column functions", { j <- collect(select(df, alias(to_json(df$info), "json"))) expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") df <- as.DataFrame(j) - schema <- structType(structField("age", "integer"), - structField("height", "double")) - s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) - expect_equal(ncol(s), 1) - expect_equal(nrow(s), 3) - expect_is(s[[1]][[1]], "struct") - expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + schemas <- list(structType(structField("age", "integer"), structField("height", "double")), + "age INT, height DOUBLE") + for (schema in schemas) { + s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) + expect_equal(ncol(s), 1) + expect_equal(nrow(s), 3) + expect_is(s[[1]][[1]], "struct") + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + } # passing option df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) @@ -1504,14 +1513,15 @@ test_that("column functions", { # check if array type in string is correctly supported. jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" df <- as.DataFrame(list(list("people" = jsonArr))) - schema <- structType(structField("name", "string")) - arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) - expect_equal(ncol(arr), 1) - expect_equal(nrow(arr), 1) - expect_is(arr[[1]][[1]], "list") - expect_equal(length(arr$arrcol[[1]]), 2) - expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") - expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + for (schema in list(structType(structField("name", "string")), "name STRING")) { + arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) + expect_equal(ncol(arr), 1) + expect_equal(nrow(arr), 1) + expect_is(arr[[1]][[1]], "list") + expect_equal(length(arr$arrcol[[1]]), 2) + expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") + expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + } # Test create_array() and create_map() df <- as.DataFrame(data.frame( @@ -2885,30 +2895,33 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { expect_identical(ldf, result) # Filter and add a column - schema <- structType(structField("a", "integer"), structField("b", "double"), - structField("c", "string"), structField("d", "integer")) - df1 <- dapply( - df, - function(x) { - y <- x[x$a > 1, ] - y <- cbind(y, y$a + 1L) - }, - schema) - result <- collect(df1) - expected <- ldf[ldf$a > 1, ] - expected$d <- expected$a + 1L - rownames(expected) <- NULL - expect_identical(expected, result) - - result <- dapplyCollect( - df, - function(x) { - y <- x[x$a > 1, ] - y <- cbind(y, y$a + 1L) - }) - expected1 <- expected - names(expected1) <- names(result) - expect_identical(expected1, result) + schemas <- list(structType(structField("a", "integer"), structField("b", "double"), + structField("c", "string"), structField("d", "integer")), + "a INT, b DOUBLE, c STRING, d INT") + for (schema in schemas) { + df1 <- dapply( + df, + function(x) { + y <- x[x$a > 1, ] + y <- cbind(y, y$a + 1L) + }, + schema) + result <- collect(df1) + expected <- ldf[ldf$a > 1, ] + expected$d <- expected$a + 1L + rownames(expected) <- NULL + expect_identical(expected, result) + + result <- dapplyCollect( + df, + function(x) { + y <- x[x$a > 1, ] + y <- cbind(y, y$a + 1L) + }) + expected1 <- expected + names(expected1) <- names(result) + expect_identical(expected1, result) + } # Remove the added column df2 <- dapply( @@ -3020,29 
+3033,32 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { # Computes the sum of second column by grouping on the first and third columns # and checks if the sum is larger than 2 - schema <- structType(structField("a", "integer"), structField("e", "boolean")) - df2 <- gapply( - df, - c(df$"a", df$"c"), - function(key, x) { - y <- data.frame(key[1], sum(x$b) > 2) - }, - schema) - actual <- collect(df2)$e - expected <- c(TRUE, TRUE) - expect_identical(actual, expected) - - df2Collect <- gapplyCollect( - df, - c(df$"a", df$"c"), - function(key, x) { - y <- data.frame(key[1], sum(x$b) > 2) - colnames(y) <- c("a", "e") - y - }) - actual <- df2Collect$e + schemas <- list(structType(structField("a", "integer"), structField("e", "boolean")), + "a INT, e BOOLEAN") + for (schema in schemas) { + df2 <- gapply( + df, + c(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + }, + schema) + actual <- collect(df2)$e + expected <- c(TRUE, TRUE) expect_identical(actual, expected) + df2Collect <- gapplyCollect( + df, + c(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + colnames(y) <- c("a", "e") + y + }) + actual <- df2Collect$e + expect_identical(actual, expected) + } + # Computes the arithmetic mean of the second column by grouping # on the first and third columns. Output the groupping value and the average. schema <- structType(structField("a", "integer"), structField("c", "string"), diff --git a/bin/load-spark-env.cmd b/bin/load-spark-env.cmd index 0977025c2036..993aa31a4c37 100644 --- a/bin/load-spark-env.cmd +++ b/bin/load-spark-env.cmd @@ -35,21 +35,21 @@ if [%SPARK_ENV_LOADED%] == [] ( rem Setting SPARK_SCALA_VERSION if not already set. -set ASSEMBLY_DIR2="%SPARK_HOME%\assembly\target\scala-2.11" -set ASSEMBLY_DIR1="%SPARK_HOME%\assembly\target\scala-2.10" +rem set ASSEMBLY_DIR2="%SPARK_HOME%\assembly\target\scala-2.11" +rem set ASSEMBLY_DIR1="%SPARK_HOME%\assembly\target\scala-2.12" if [%SPARK_SCALA_VERSION%] == [] ( - if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% ( - echo "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." - echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd." - exit 1 - ) - if exist %ASSEMBLY_DIR2% ( + rem if exist %ASSEMBLY_DIR2% if exist %ASSEMBLY_DIR1% ( + rem echo "Presence of build for multiple Scala versions detected." + rem echo "Either clean one of them or, set SPARK_SCALA_VERSION=2.11 in spark-env.cmd." + rem exit 1 + rem ) + rem if exist %ASSEMBLY_DIR2% ( set SPARK_SCALA_VERSION=2.11 - ) else ( - set SPARK_SCALA_VERSION=2.10 - ) + rem ) else ( + rem set SPARK_SCALA_VERSION=2.12 + rem ) ) exit /b 0 diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 8a2f709960a2..9de62039c51e 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -46,18 +46,18 @@ fi if [ -z "$SPARK_SCALA_VERSION" ]; then - ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11" - ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.10" + #ASSEMBLY_DIR2="${SPARK_HOME}/assembly/target/scala-2.11" + #ASSEMBLY_DIR1="${SPARK_HOME}/assembly/target/scala-2.12" - if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then - echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 - echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 - exit 1 - fi + #if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then + # echo -e "Presence of build for multiple Scala versions detected." 
1>&2 + # echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 + # exit 1 + #fi - if [ -d "$ASSEMBLY_DIR2" ]; then + #if [ -d "$ASSEMBLY_DIR2" ]; then export SPARK_SCALA_VERSION="2.11" - else - export SPARK_SCALA_VERSION="2.10" - fi + #else + # export SPARK_SCALA_VERSION="2.12" + #fi fi diff --git a/bin/pyspark b/bin/pyspark index 98387c2ec5b8..dd286277c1fc 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.6-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m "$1" + exec "$PYSPARK_DRIVER_PYTHON" -m "$@" exit fi diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index f211c0873ad2..46d4d5c883cf 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.6-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/build/mvn b/build/mvn index 20f78869a6ce..9beb8af2878a 100755 --- a/build/mvn +++ b/build/mvn @@ -87,13 +87,13 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { - local zinc_path="zinc-0.3.11/bin/zinc" + local zinc_path="zinc-0.3.15/bin/zinc" [ ! 
-f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} install_app \ - "${TYPESAFE_MIRROR}/zinc/0.3.11" \ - "zinc-0.3.11.tgz" \ + "${TYPESAFE_MIRROR}/zinc/0.3.15" \ + "zinc-0.3.15.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" } diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 066970f24205..0254d0cefc36 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -90,7 +90,8 @@ org.apache.spark spark-tags_${scala.binary.version} - + test + + runtime log4j @@ -2115,9 +2116,9 @@ ${antlr4.version} - ${jline.groupid} + jline jline - ${jline.version} + 2.12.1 org.apache.commons @@ -2135,6 +2136,25 @@ paranamer ${paranamer.version} + + org.apache.arrow + arrow-vector + ${arrow.version} + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-databind + + + io.netty + netty-handler + + + @@ -2170,6 +2190,7 @@ --> org.jboss.netty org.codehaus.groovy + *:*_2.10 true @@ -2224,6 +2245,8 @@ -unchecked -deprecation -feature + -explaintypes + -Yno-adapted-args -Xms1024m @@ -2597,6 +2620,7 @@ org.eclipse.jetty:jetty-util org.eclipse.jetty:jetty-server com.google.guava:guava + org.jpmml:* @@ -2611,6 +2635,14 @@ com.google.common ${spark.shade.packageName}.guava + + org.dmg.pmml + ${spark.shade.packageName}.dmg.pmml + + + org.jpmml + ${spark.shade.packageName}.jpmml + @@ -2912,44 +2944,6 @@ - - scala-2.10 - - scala-2.10 - - - 2.10.6 - 2.10 - ${scala.version} - org.scala-lang - - - - - org.apache.maven.plugins - maven-enforcer-plugin - - - enforce-versions - - enforce - - - - - - *:*_2.11 - - - - - - - - - - - test-java-home @@ -2960,16 +2954,18 @@ + scala-2.11 - - !scala-2.10 - + + + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4cc4a7fc2eb7..ccbd1e92d5cb 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -92,19 +92,11 @@ object SparkBuild extends PomBuild { val projectsMap: mutable.Map[String, Seq[Setting[_]]] = mutable.Map.empty override val profiles = { - val profiles = Properties.propOrNone("sbt.maven.profiles") orElse Properties.envOrNone("SBT_MAVEN_PROFILES") match { + Properties.envOrNone("SBT_MAVEN_PROFILES") match { case None => Seq("sbt") case Some(v) => v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } - - if (System.getProperty("scala-2.10") == "") { - // To activate scala-2.10 profile, replace empty property value to non-empty value - // in the same way as Maven which handles -Dname as -Dname=true before executes build process. 
- // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 - System.setProperty("scala-2.10", "true") - } - profiles } Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { @@ -240,9 +232,7 @@ object SparkBuild extends PomBuild { }, javacJVMVersion := "1.8", - // SBT Scala 2.10 build still doesn't support Java 8, because scalac 2.10 doesn't, but, - // it also doesn't touch Java 8 code and it's OK to emit Java 7 bytecode in this case - scalacJVMVersion := (if (System.getProperty("scala-2.10") == "true") "1.7" else "1.8"), + scalacJVMVersion := "1.8", javacOptions in Compile ++= Seq( "-encoding", "UTF-8", @@ -492,7 +482,6 @@ object OldDeps { def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", - scalaVersion := "2.10.5", libraryDependencies := allPreviousArtifactKeys.value.flatten ) } @@ -772,13 +761,7 @@ object CopyDependencies { object TestSettings { import BuildCommons._ - private val scalaBinaryVersion = - if (System.getProperty("scala-2.10") == "true") { - "2.10" - } else { - "2.11" - } - + private val scalaBinaryVersion = "2.11" lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, diff --git a/python/README.md b/python/README.md index 0a5c8010b848..84ec88141cb0 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.4), but additional sub-packages have their own requirements (including numpy and pandas). \ No newline at end of file +At its core PySpark depends on Py4J (currently version 0.10.6), but additional sub-packages have their own requirements (including numpy and pandas). diff --git a/python/docs/Makefile b/python/docs/Makefile index 5e4cfb8ab6fe..09898f29950e 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.4-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.6-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/lib/py4j-0.10.4-src.zip b/python/lib/py4j-0.10.4-src.zip deleted file mode 100644 index 8c3829e32872..000000000000 Binary files a/python/lib/py4j-0.10.4-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.10.6-src.zip b/python/lib/py4j-0.10.6-src.zip new file mode 100644 index 000000000000..2f8edcc0c0b8 Binary files /dev/null and b/python/lib/py4j-0.10.6-src.zip differ diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 9b345ac73f3d..948806a5c936 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1265,8 +1265,8 @@ def theta(self): @inherit_doc class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasMaxIter, HasTol, HasSeed, HasStepSize, JavaMLWritable, - JavaMLReadable): + HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver, + JavaMLWritable, JavaMLReadable): """ Classifier trainer based on the Multilayer Perceptron. Each layer has sigmoid activation function, output layer has softmax. @@ -1407,20 +1407,6 @@ def getStepSize(self): """ return self.getOrDefault(self.stepSize) - @since("2.0.0") - def setSolver(self, value): - """ - Sets the value of :py:attr:`solver`. 
- """ - return self._set(solver=value) - - @since("2.0.0") - def getSolver(self): - """ - Gets the value of solver or its default value. - """ - return self.getOrDefault(self.solver) - @since("2.0.0") def setInitialWeights(self, value): """ diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 77de1cc18246..7eb1b9fac2f5 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -314,7 +314,8 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable) @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): """ Maps a column of continuous features to a column of feature buckets. @@ -398,20 +399,6 @@ def getSplits(self): """ return self.getOrDefault(self.splits) - @since("2.1.0") - def setHandleInvalid(self, value): - """ - Sets the value of :py:attr:`handleInvalid`. - """ - return self._set(handleInvalid=value) - - @since("2.1.0") - def getHandleInvalid(self): - """ - Gets the value of :py:attr:`handleInvalid` or its default value. - """ - return self.getOrDefault(self.handleInvalid) - @inherit_doc class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): @@ -1623,7 +1610,8 @@ def getDegree(self): @inherit_doc -class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1743,20 +1731,6 @@ def getRelativeError(self): """ return self.getOrDefault(self.relativeError) - @since("2.1.0") - def setHandleInvalid(self, value): - """ - Sets the value of :py:attr:`handleInvalid`. - """ - return self._set(handleInvalid=value) - - @since("2.1.0") - def getHandleInvalid(self): - """ - Gets the value of :py:attr:`handleInvalid` or its default value. - """ - return self.getOrDefault(self.handleInvalid) - def _create_model(self, java_model): """ Private method to convert the java_model to a Python model. @@ -2132,6 +2106,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", typeConverter=TypeConverters.toString) + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "labels or NULL values). Options are 'skip' (filter out rows with " + + "invalid data), error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", stringOrderType="frequencyDesc"): @@ -2971,7 +2951,8 @@ def explainedVariance(self): @inherit_doc -class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable): +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, HasHandleInvalid, + JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -3014,6 +2995,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM True >>> loadedRF.getLabelCol() == rf.getLabelCol() True + >>> loadedRF.getHandleInvalid() == rf.getHandleInvalid() + True >>> str(loadedRF) 'RFormula(y ~ x + s) (uid=...)' >>> modelPath = temp_path + "/rFormulaModel" @@ -3052,26 +3035,37 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM "RFormula drops the same category as R when encoding strings.", typeConverter=TypeConverters.toString) + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " + + "Options are 'skip' (filter out rows with invalid values), " + + "'error' (throw an error), or 'keep' (put invalid data in a special " + + "additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, formula=None, featuresCol="features", labelCol="label", - forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"): + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", + handleInvalid="error"): """ __init__(self, formula=None, featuresCol="features", labelCol="label", \ - forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \ + handleInvalid="error") """ super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) - self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") + self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", + handleInvalid="error") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.5.0") def setParams(self, formula=None, featuresCol="features", labelCol="label", - forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"): + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", + handleInvalid="error"): """ setParams(self, formula=None, featuresCol="features", labelCol="label", \ - forceIndexLabel=False, stringIndexerOrderType="frequencyDesc") + forceIndexLabel=False, stringIndexerOrderType="frequencyDesc", \ + handleInvalid="error") Sets params for RFormula. """ kwargs = self._input_kwargs diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 2d17f95b0c44..f0ff7a5f59ab 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -95,6 +95,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction .. versionadded:: 1.4.0 """ + solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + + "options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, @@ -1371,17 +1374,22 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " + "Only applicable to the Tweedie family.", typeConverter=TypeConverters.toFloat) + solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + + "options: irls.", typeConverter=TypeConverters.toString) + offsetCol = Param(Params._dummy(), "offsetCol", "The offset column name. 
If this is not set " + + "or empty, we treat all instance offsets as 0.0", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, - variancePower=0.0, linkPower=None): + variancePower=0.0, linkPower=None, offsetCol=None): """ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ - variancePower=0.0, linkPower=None) + variancePower=0.0, linkPower=None, offsetCol=None) """ super(GeneralizedLinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -1397,12 +1405,12 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, - variancePower=0.0, linkPower=None): + variancePower=0.0, linkPower=None, offsetCol=None): """ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ - variancePower=0.0, linkPower=None) + variancePower=0.0, linkPower=None, offsetCol=None) Sets params for generalized linear regression. """ kwargs = self._input_kwargs @@ -1481,6 +1489,20 @@ def getLinkPower(self): """ return self.getOrDefault(self.linkPower) + @since("2.3.0") + def setOffsetCol(self, value): + """ + Sets the value of :py:attr:`offsetCol`. + """ + return self._set(offsetCol=value) + + @since("2.3.0") + def getOffsetCol(self): + """ + Gets the value of offsetCol or its default value. 
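# Illustrative sketch of the new offsetCol parameter, assuming an active
# SparkSession `spark`; the data and column names mirror the test_offset
# test added later in this patch and are not a canonical example.
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import GeneralizedLinearRegression

df_glr = spark.createDataFrame(
    [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
     (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)),
     (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0))],
    ["label", "weight", "offset", "features"])
glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset")
model = glr.fit(df_glr)  # per-row values of the "offset" column are used as GLM offsets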
+ """ + return self.getOrDefault(self.offsetCol) + class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 17a39472e1fe..787004765160 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -551,6 +551,27 @@ def test_rformula_string_indexer_order_type(self): for i in range(0, len(expected)): self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) + def test_string_indexer_handle_invalid(self): + df = self.spark.createDataFrame([ + (0, "a"), + (1, "d"), + (2, None)], ["id", "label"]) + + si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", + stringOrderType="alphabetAsc") + model1 = si1.fit(df) + td1 = model1.transform(df) + actual1 = td1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] + self.assertEqual(actual1, expected1) + + si2 = si1.setHandleInvalid("skip") + model2 = si2.fit(df) + td2 = model2.transform(df) + actual2 = td2.select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] + self.assertEqual(actual2, expected2) + class HasInducedError(Params): @@ -1270,6 +1291,20 @@ def test_tweedie_distribution(self): self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) + def test_offset(self): + + df = self.spark.createDataFrame( + [(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)), + (0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)), + (0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)), + (0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"]) + + glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset") + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581], + atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4)) + class FPGrowthTests(SparkSessionTestCase): def setUp(self): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 873b00fbe534..334761009003 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -608,7 +608,7 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p sort records by their keys. >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)]) - >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, 2) + >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True) >>> rdd2.glom().collect() [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]] """ @@ -627,7 +627,6 @@ def sortPartition(iterator): def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. - # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] >>> sc.parallelize(tmp).sortByKey().first() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ea5e00e9eeef..d5c2a7518b18 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -182,6 +182,23 @@ def loads(self, obj): raise NotImplementedError +class ArrowSerializer(FramedSerializer): + """ + Serializes an Arrow stream. 
+ """ + + def dumps(self, obj): + raise NotImplementedError + + def loads(self, obj): + import pyarrow as pa + reader = pa.RecordBatchFileReader(pa.BufferReader(obj)) + return reader.read_all() + + def __repr__(self): + return "ArrowSerializer" + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 426f07cd9410..c44ab247fd3d 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -232,6 +232,23 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a java UDAF so it can be used in SQL statements. + + :param name: name of the UDAF + :param javaClassName: fully qualified name of java class + + >>> sqlContext.registerJavaUDAF("javaUDAF", + ... "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ @@ -551,6 +568,12 @@ def __init__(self, sqlContext): def register(self, name, f, returnType=StringType()): return self.sqlContext.registerFunction(name, f, returnType) + def registerJavaFunction(self, name, javaClassName, returnType=None): + self.sqlContext.registerJavaFunction(name, javaClassName, returnType) + + def registerJavaUDAF(self, name, javaClassName): + self.sqlContext.registerJavaUDAF(name, javaClassName) + register.__doc__ = SQLContext.registerFunction.__doc__ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0649271ed224..944739bcd207 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -29,7 +29,8 @@ from pyspark import copy_func, since from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix -from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \ + UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import _parse_datatype_json_string @@ -833,6 +834,8 @@ def join(self, other, on=None, how=None): else: if how is None: how = "inner" + if on is None: + on = self._jseq([]) assert isinstance(how, basestring), "how should be basestring" jdf = self._jdf.join(other._jdf, on, how) return DataFrame(jdf, self.sql_ctx) @@ -1708,7 +1711,8 @@ def toDF(self, *cols): @since(1.3) def toPandas(self): - """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. + """ + Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. This is only available if Pandas is installed and available. 
@@ -1721,18 +1725,42 @@ def toPandas(self): 1 5 Bob """ import pandas as pd + if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": + try: + import pyarrow + tables = self._collectAsArrow() + if tables: + table = pyarrow.concat_tables(tables) + return table.to_pandas() + else: + return pd.DataFrame.from_records([], columns=self.columns) + except ImportError as e: + msg = "note: pyarrow must be installed and available on calling Python process " \ + "if using spark.sql.execution.arrow.enable=true" + raise ImportError("%s\n%s" % (e.message, msg)) + else: + dtype = {} + for field in self.schema: + pandas_type = _to_corrected_pandas_type(field.dataType) + if pandas_type is not None: + dtype[field.name] = pandas_type - dtype = {} - for field in self.schema: - pandas_type = _to_corrected_pandas_type(field.dataType) - if pandas_type is not None: - dtype[field.name] = pandas_type + pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) - pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) + for f, t in dtype.items(): + pdf[f] = pdf[f].astype(t, copy=False) + return pdf - for f, t in dtype.items(): - pdf[f] = pdf[f].astype(t, copy=False) - return pdf + def _collectAsArrow(self): + """ + Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed + and available. + + .. note:: Experimental. + """ + with SCCallSiteSync(self._sc) as css: + port = self._jdf.collectAsArrowToPython() + return list(_load_from_socket(port, ArrowSerializer())) ########################################################################################## # Pandas compatibility diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b14fc86fcb39..d2126b279f40 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -67,9 +67,14 @@ def _(): _.__doc__ = 'Window function: ' + doc return _ +_lit_doc = """ + Creates a :class:`Column` of literal value. + >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1) + [Row(height=5, spark_user=True)] + """ _functions = { - 'lit': 'Creates a :class:`Column` of literal value.', + 'lit': _lit_doc, 'col': 'Returns a :class:`Column` based on the given column name.', 'column': 'Returns a :class:`Column` based on the given column name.', 'asc': 'Returns a sort expression based on the ascending order of the given column name.', @@ -95,10 +100,13 @@ def _(): '0.0 through pi.', 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + '-pi/2 through pi/2.', - 'atan': 'Computes the tangent inverse of the given value.', + 'atan': 'Computes the tangent inverse of the given value; the returned angle is in the range' + + '-pi/2 through pi/2', 'cbrt': 'Computes the cube-root of the given value.', 'ceil': 'Computes the ceiling of the given value.', - 'cos': 'Computes the cosine of the given value.', + 'cos': """Computes the cosine of the given value. + + :param col: :class:`DoubleType` column, units in radians.""", 'cosh': 'Computes the hyperbolic cosine of the given value.', 'exp': 'Computes the exponential of the given value.', 'expm1': 'Computes the exponential of the given value minus one.', @@ -109,15 +117,33 @@ def _(): 'rint': 'Returns the double value that is closest in value to the argument and' + ' is equal to a mathematical integer.', 'signum': 'Computes the signum of the given value.', - 'sin': 'Computes the sine of the given value.', + 'sin': """Computes the sine of the given value. 
+ + :param col: :class:`DoubleType` column, units in radians.""", 'sinh': 'Computes the hyperbolic sine of the given value.', - 'tan': 'Computes the tangent of the given value.', + 'tan': """Computes the tangent of the given value. + + :param col: :class:`DoubleType` column, units in radians.""", 'tanh': 'Computes the hyperbolic tangent of the given value.', - 'toDegrees': '.. note:: Deprecated in 2.1, use degrees instead.', - 'toRadians': '.. note:: Deprecated in 2.1, use radians instead.', + 'toDegrees': '.. note:: Deprecated in 2.1, use :func:`degrees` instead.', + 'toRadians': '.. note:: Deprecated in 2.1, use :func:`radians` instead.', 'bitwiseNOT': 'Computes bitwise not.', } +_collect_list_doc = """ + Aggregate function: returns a list of objects with duplicates. + + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) + >>> df2.agg(collect_list('age')).collect() + [Row(collect_list(age)=[2, 5, 5])] + """ +_collect_set_doc = """ + Aggregate function: returns a set of objects with duplicate elements eliminated. + + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) + >>> df2.agg(collect_set('age')).collect() + [Row(collect_set(age)=[5, 2])] + """ _functions_1_6 = { # unary math functions 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' + @@ -131,9 +157,8 @@ def _(): 'var_pop': 'Aggregate function: returns the population variance of the values in a group.', 'skewness': 'Aggregate function: returns the skewness of the values in a group.', 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', - 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', - 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + - ' eliminated.', + 'collect_list': _collect_list_doc, + 'collect_set': _collect_set_doc } _functions_2_1 = { @@ -147,7 +172,7 @@ def _(): # math functions that take two arguments as input _binary_mathfunctions = { 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + - 'polar coordinates (r, theta).', + 'polar coordinates (r, theta). Units in radians.', 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } @@ -200,17 +225,20 @@ def _(): @since(1.3) def approxCountDistinct(col, rsd=None): """ - .. note:: Deprecated in 2.1, use approx_count_distinct instead. + .. note:: Deprecated in 2.1, use :func:`approx_count_distinct` instead. """ return approx_count_distinct(col, rsd) @since(2.1) def approx_count_distinct(col, rsd=None): - """Returns a new :class:`Column` for approximate distinct count of ``col``. + """Aggregate function: returns a new :class:`Column` for approximate distinct count of column `col`. - >>> df.agg(approx_count_distinct(df.age).alias('c')).collect() - [Row(c=2)] + :param rsd: maximum estimation error allowed (default = 0.05). For rsd < 0.01, it is more + efficient to use :func:`countDistinct` + + >>> df.agg(approx_count_distinct(df.age).alias('distinct_ages')).collect() + [Row(distinct_ages=2)] """ sc = SparkContext._active_spark_context if rsd is None: @@ -267,8 +295,7 @@ def coalesce(*cols): @since(1.6) def corr(col1, col2): - """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` - and ``col2``. + """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` and ``col2``. 
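# Illustrative sketch of the rsd argument documented above, assuming the
# doctest DataFrame `df` with an "age" column; per the docstring, for
# rsd < 0.01 countDistinct is usually the more efficient choice.
from pyspark.sql.functions import approx_count_distinct, countDistinct

df.agg(approx_count_distinct("age", rsd=0.02).alias("approx_ages"),
       countDistinct("age").alias("exact_ages")).show()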
>>> a = range(20) >>> b = [2 * x for x in range(20)] @@ -282,8 +309,7 @@ def corr(col1, col2): @since(2.0) def covar_pop(col1, col2): - """Returns a new :class:`Column` for the population covariance of ``col1`` - and ``col2``. + """Returns a new :class:`Column` for the population covariance of ``col1`` and ``col2``. >>> a = [1] * 10 >>> b = [1] * 10 @@ -297,8 +323,7 @@ def covar_pop(col1, col2): @since(2.0) def covar_samp(col1, col2): - """Returns a new :class:`Column` for the sample covariance of ``col1`` - and ``col2``. + """Returns a new :class:`Column` for the sample covariance of ``col1`` and ``col2``. >>> a = [1] * 10 >>> b = [1] * 10 @@ -450,7 +475,7 @@ def monotonically_increasing_id(): def nanvl(col1, col2): """Returns col1 if it is not NaN, or col2 if col1 is NaN. - Both inputs should be floating point columns (DoubleType or FloatType). + Both inputs should be floating point columns (:class:`DoubleType` or :class:`FloatType`). >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() @@ -460,10 +485,15 @@ def nanvl(col1, col2): return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) +@ignore_unicode_prefix @since(1.4) def rand(seed=None): """Generates a random column with independent and identically distributed (i.i.d.) samples from U[0.0, 1.0]. + + >>> df.withColumn('rand', rand(seed=42) * 3).collect() + [Row(age=2, name=u'Alice', rand=1.1568609015300986), + Row(age=5, name=u'Bob', rand=1.403379671529166)] """ sc = SparkContext._active_spark_context if seed is not None: @@ -473,10 +503,15 @@ def rand(seed=None): return Column(jc) +@ignore_unicode_prefix @since(1.4) def randn(seed=None): """Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution. + + >>> df.withColumn('randn', randn(seed=42)).collect() + [Row(age=2, name=u'Alice', randn=-0.7556247885860078), + Row(age=5, name=u'Bob', randn=-0.0861619008451133)] """ sc = SparkContext._active_spark_context if seed is not None: @@ -760,7 +795,7 @@ def ntile(n): @since(1.5) def current_date(): """ - Returns the current date as a date column. + Returns the current date as a :class:`DateType` column. """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.current_date()) @@ -768,7 +803,7 @@ def current_date(): def current_timestamp(): """ - Returns the current timestamp as a timestamp column. + Returns the current timestamp as a :class:`TimestampType` column. """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.current_timestamp()) @@ -787,8 +822,8 @@ def date_format(date, format): .. note:: Use when ever possible specialized functions like `year`. These benefit from a specialized implementation. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect() [Row(date=u'04/08/2015')] """ sc = SparkContext._active_spark_context @@ -800,8 +835,8 @@ def year(col): """ Extract the year of a given date as integer. 
- >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(year('a').alias('year')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(year('dt').alias('year')).collect() [Row(year=2015)] """ sc = SparkContext._active_spark_context @@ -813,8 +848,8 @@ def quarter(col): """ Extract the quarter of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(quarter('a').alias('quarter')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(quarter('dt').alias('quarter')).collect() [Row(quarter=2)] """ sc = SparkContext._active_spark_context @@ -826,8 +861,8 @@ def month(col): """ Extract the month of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(month('a').alias('month')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(month('dt').alias('month')).collect() [Row(month=4)] """ sc = SparkContext._active_spark_context @@ -839,8 +874,8 @@ def dayofmonth(col): """ Extract the day of the month of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(dayofmonth('a').alias('day')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofmonth('dt').alias('day')).collect() [Row(day=8)] """ sc = SparkContext._active_spark_context @@ -852,8 +887,8 @@ def dayofyear(col): """ Extract the day of the year of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(dayofyear('a').alias('day')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofyear('dt').alias('day')).collect() [Row(day=98)] """ sc = SparkContext._active_spark_context @@ -865,8 +900,8 @@ def hour(col): """ Extract the hours of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) - >>> df.select(hour('a').alias('hour')).collect() + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> df.select(hour('ts').alias('hour')).collect() [Row(hour=13)] """ sc = SparkContext._active_spark_context @@ -878,8 +913,8 @@ def minute(col): """ Extract the minutes of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) - >>> df.select(minute('a').alias('minute')).collect() + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> df.select(minute('ts').alias('minute')).collect() [Row(minute=8)] """ sc = SparkContext._active_spark_context @@ -891,8 +926,8 @@ def second(col): """ Extract the seconds of a given date as integer. - >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) - >>> df.select(second('a').alias('second')).collect() + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['ts']) + >>> df.select(second('ts').alias('second')).collect() [Row(second=15)] """ sc = SparkContext._active_spark_context @@ -904,8 +939,8 @@ def weekofyear(col): """ Extract the week number of a given date as integer. 
- >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(weekofyear(df.a).alias('week')).collect() + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(weekofyear(df.dt).alias('week')).collect() [Row(week=15)] """ sc = SparkContext._active_spark_context @@ -917,9 +952,9 @@ def date_add(start, days): """ Returns the date that is `days` days after `start` - >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) - >>> df.select(date_add(df.d, 1).alias('d')).collect() - [Row(d=datetime.date(2015, 4, 9))] + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(date_add(df.dt, 1).alias('next_date')).collect() + [Row(next_date=datetime.date(2015, 4, 9))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.date_add(_to_java_column(start), days)) @@ -930,9 +965,9 @@ def date_sub(start, days): """ Returns the date that is `days` days before `start` - >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) - >>> df.select(date_sub(df.d, 1).alias('d')).collect() - [Row(d=datetime.date(2015, 4, 7))] + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(date_sub(df.dt, 1).alias('prev_date')).collect() + [Row(prev_date=datetime.date(2015, 4, 7))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.date_sub(_to_java_column(start), days)) @@ -956,9 +991,9 @@ def add_months(start, months): """ Returns the date that is `months` months after `start` - >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) - >>> df.select(add_months(df.d, 1).alias('d')).collect() - [Row(d=datetime.date(2015, 5, 8))] + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(add_months(df.dt, 1).alias('next_month')).collect() + [Row(next_month=datetime.date(2015, 5, 8))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.add_months(_to_java_column(start), months)) @@ -969,8 +1004,8 @@ def months_between(date1, date2): """ Returns the number of months between date1 and date2. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) - >>> df.select(months_between(df.t, df.d).alias('months')).collect() + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) + >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() [Row(months=3.9495967...)] """ sc = SparkContext._active_spark_context @@ -1073,12 +1108,19 @@ def last_day(date): return Column(sc._jvm.functions.last_day(_to_java_column(date))) +@ignore_unicode_prefix @since(1.5) def from_unixtime(timestamp, format="yyyy-MM-dd HH:mm:ss"): """ Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string representing the timestamp of that moment in the current system time zone in the given format. + + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> time_df = spark.createDataFrame([(1428476400,)], ['unix_time']) + >>> time_df.select(from_unixtime('unix_time').alias('ts')).collect() + [Row(ts=u'2015-04-08 00:00:00')] + >>> spark.conf.unset("spark.sql.session.timeZone") """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.from_unixtime(_to_java_column(timestamp), format)) @@ -1092,6 +1134,12 @@ def unix_timestamp(timestamp=None, format='yyyy-MM-dd HH:mm:ss'): locale, return null if fail. if `timestamp` is None, then it returns current timestamp. 
+ + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> time_df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> time_df.select(unix_timestamp('dt', 'yyyy-MM-dd').alias('unix_time')).collect() + [Row(unix_time=1428476400)] + >>> spark.conf.unset("spark.sql.session.timeZone") """ sc = SparkContext._active_spark_context if timestamp is None: @@ -1106,8 +1154,8 @@ def from_utc_timestamp(timestamp, tz): that corresponds to the same time of day in the given timezone. >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect() - [Row(t=datetime.datetime(1997, 2, 28, 2, 30))] + >>> df.select(from_utc_timestamp(df.t, "PST").alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.from_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1119,9 +1167,9 @@ def to_utc_timestamp(timestamp, tz): Given a timestamp, which corresponds to a certain time of day in the given timezone, returns another timestamp that corresponds to the same time of day in UTC. - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect() - [Row(t=datetime.datetime(1997, 2, 28, 18, 30))] + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['ts']) + >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) @@ -1839,15 +1887,20 @@ def from_json(col, schema, options={}): string. :param col: string column in json format - :param schema: a StructType or ArrayType of StructType to use when parsing the json column + :param schema: a StructType or ArrayType of StructType to use when parsing the json column. :param options: options to control parsing. accepts the same options as the json datasource + .. note:: Since Spark 2.3, the DDL-formatted string or a JSON format string is also + supported for ``schema``. 
+ >>> from pyspark.sql.types import * >>> data = [(1, '''{"a": 1}''')] >>> schema = StructType([StructField("a", IntegerType())]) >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=Row(a=1))] + >>> df.select(from_json(df.value, "a INT").alias("json")).collect() + [Row(json=Row(a=1))] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) @@ -1856,7 +1909,9 @@ def from_json(col, schema, options={}): """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.from_json(_to_java_column(col), schema.json(), options) + if isinstance(schema, DataType): + schema = schema.json() + jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options) return Column(jc) @@ -1983,15 +2038,25 @@ def __init__(self, func, returnType, name=None): "{0}".format(type(func))) self.func = func - self.returnType = ( - returnType if isinstance(returnType, DataType) - else _parse_datatype_string(returnType)) + self._returnType = returnType # Stores UserDefinedPythonFunctions jobj, once initialized + self._returnType_placeholder = None self._judf_placeholder = None self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) + @property + def returnType(self): + # This makes sure this is called after SparkContext is initialized. + # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. + if self._returnType_placeholder is None: + if isinstance(self._returnType, DataType): + self._returnType_placeholder = self._returnType + else: + self._returnType_placeholder = _parse_datatype_string(self._returnType) + return self._returnType_placeholder + @property def _judf(self): # It is possible that concurrent access, to newly created UDF, @@ -2096,7 +2161,7 @@ def _test(): sc = spark.sparkContext globs['sc'] = sc globs['spark'] = spark - globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() + globs['df'] = spark.createDataFrame([Row(name='Alice', age=2), Row(name='Bob', age=5)]) (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index e3bf0f35ea15..2cc0e2d1d7b8 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -33,7 +33,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader -from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \ +from pyspark.sql.types import Row, DataType, StringType, StructType, _make_type_verifier, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string from pyspark.sql.utils import install_exception_handler @@ -514,17 +514,21 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr schema = [str(x) for x in data.columns] data = [r.tolist() for r in data.to_records(index=False)] - verify_func = _verify_type if verifySchema else lambda _, t: True if isinstance(schema, StructType): + verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True + def prepare(obj): - verify_func(obj, schema) + verify_func(obj) return obj elif isinstance(schema, DataType): dataType = schema schema = StructType().add("value", schema) + verify_func 
= _make_type_verifier( + dataType, name="field value") if verifySchema else lambda _: True + def prepare(obj): - verify_func(obj, dataType) + verify_func(obj) return obj, else: if isinstance(schema, list): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0a1cd6856b8e..29e48a6ccf76 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -57,13 +57,22 @@ from pyspark import SparkContext from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row from pyspark.sql.types import * -from pyspark.sql.types import UserDefinedType, _infer_type -from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests +from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier +from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException +_have_arrow = False +try: + import pyarrow + _have_arrow = True +except: + # No Arrow, but that's okay, we'll skip those tests + pass + + class UTCOffsetTimezone(datetime.tzinfo): """ Specifies timezone in UTC offset @@ -481,6 +490,16 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + def test_non_existed_udf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + + def test_non_existed_udaf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", + lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) + def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", @@ -852,7 +871,7 @@ def test_convert_row_to_dict(self): self.assertEqual(1.0, row.asDict()['d']['key'].c) def test_udt(self): - from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier from pyspark.sql.tests import ExamplePointUDT, ExamplePoint def check_datatype(datatype): @@ -868,8 +887,8 @@ def check_datatype(datatype): check_datatype(structtype_with_udt) p = ExamplePoint(1.0, 2.0) self.assertEqual(_infer_type(p), ExamplePointUDT()) - _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT())) + _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0)) + self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0])) check_datatype(PythonOnlyUDT()) structtype_with_udt = StructType([StructField("label", DoubleType(), False), @@ -877,8 +896,10 @@ def check_datatype(datatype): check_datatype(structtype_with_udt) p = PythonOnlyPoint(1.0, 2.0) self.assertEqual(_infer_type(p), PythonOnlyUDT()) - _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) - self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0)) + self.assertRaises( + ValueError, + lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0])) def test_simple_udt_in_df(self): schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) @@ -1234,6 +1255,31 @@ def 
test_struct_type(self): with self.assertRaises(TypeError): not_a_field = struct1[9.9] + def test_parse_datatype_string(self): + from pyspark.sql.types import _all_atomic_types, _parse_datatype_string + for k, t in _all_atomic_types.items(): + if t != NullType: + self.assertEqual(t(), _parse_datatype_string(k)) + self.assertEqual(IntegerType(), _parse_datatype_string("int")) + self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)")) + self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )")) + self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)")) + self.assertEqual( + ArrayType(IntegerType()), + _parse_datatype_string("array")) + self.assertEqual( + MapType(IntegerType(), DoubleType()), + _parse_datatype_string("map< int, double >")) + self.assertEqual( + StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), + _parse_datatype_string("struct")) + self.assertEqual( + StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), + _parse_datatype_string("a:int, c:double")) + self.assertEqual( + StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]), + _parse_datatype_string("a INT, c DOUBLE")) + def test_metadata_null(self): from pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), @@ -2021,6 +2067,22 @@ def test_toDF_with_schema_string(self): self.assertEqual(df.schema.simpleString(), "struct") self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + def test_join_without_on(self): + df1 = self.spark.range(1).toDF("a") + df2 = self.spark.range(1).toDF("b") + + try: + self.spark.conf.set("spark.sql.crossJoin.enabled", "false") + self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) + + self.spark.conf.set("spark.sql.crossJoin.enabled", "true") + actual = df1.join(df2, how="inner").collect() + expected = [Row(a=0, b=0)] + self.assertEqual(actual, expected) + finally: + # We should unset this. Otherwise, other tests are affected. 
+ self.spark.conf.unset("spark.sql.crossJoin.enabled") + # Regression test for invalid join methods when on is None, Spark-14761 def test_invalid_join_method(self): df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"]) @@ -2314,6 +2376,12 @@ def test_to_pandas(self): self.assertEquals(types[2], np.bool) self.assertEquals(types[3], np.float32) + def test_create_dataframe_from_array_of_long(self): + import array + data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))] + df = self.spark.createDataFrame(data) + self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807])) + class HiveSparkSubmitTests(SparkSubmitTests): @@ -2620,6 +2688,262 @@ def range_frame_match(): importlib.reload(window) + +class DataTypeVerificationTests(unittest.TestCase): + + def test_verify_type_exception_msg(self): + self.assertRaisesRegexp( + ValueError, + "test_name", + lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None)) + + schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))]) + self.assertRaisesRegexp( + TypeError, + "field b in field a", + lambda: _make_type_verifier(schema)([["data"]])) + + def test_verify_type_ok_nullable(self): + obj = None + types = [IntegerType(), FloatType(), StringType(), StructType([])] + for data_type in types: + try: + _make_type_verifier(data_type, nullable=True)(obj) + except Exception: + self.fail("verify_type(%s, %s, nullable=True)" % (obj, data_type)) + + def test_verify_type_not_nullable(self): + import array + import datetime + import decimal + + schema = StructType([ + StructField('s', StringType(), nullable=False), + StructField('i', IntegerType(), nullable=True)]) + + class MyObj: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + # obj, data_type + success_spec = [ + # String + ("", StringType()), + (u"", StringType()), + (1, StringType()), + (1.0, StringType()), + ([], StringType()), + ({}, StringType()), + + # UDT + (ExamplePoint(1.0, 2.0), ExamplePointUDT()), + + # Boolean + (True, BooleanType()), + + # Byte + (-(2**7), ByteType()), + (2**7 - 1, ByteType()), + + # Short + (-(2**15), ShortType()), + (2**15 - 1, ShortType()), + + # Integer + (-(2**31), IntegerType()), + (2**31 - 1, IntegerType()), + + # Long + (2**64, LongType()), + + # Float & Double + (1.0, FloatType()), + (1.0, DoubleType()), + + # Decimal + (decimal.Decimal("1.0"), DecimalType()), + + # Binary + (bytearray([1, 2]), BinaryType()), + + # Date/Timestamp + (datetime.date(2000, 1, 2), DateType()), + (datetime.datetime(2000, 1, 2, 3, 4), DateType()), + (datetime.datetime(2000, 1, 2, 3, 4), TimestampType()), + + # Array + ([], ArrayType(IntegerType())), + (["1", None], ArrayType(StringType(), containsNull=True)), + ([1, 2], ArrayType(IntegerType())), + ((1, 2), ArrayType(IntegerType())), + (array.array('h', [1, 2]), ArrayType(IntegerType())), + + # Map + ({}, MapType(StringType(), IntegerType())), + ({"a": 1}, MapType(StringType(), IntegerType())), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=True)), + + # Struct + ({"s": "a", "i": 1}, schema), + ({"s": "a", "i": None}, schema), + ({"s": "a"}, schema), + ({"s": "a", "f": 1.0}, schema), + (Row(s="a", i=1), schema), + (Row(s="a", i=None), schema), + (Row(s="a", i=1, f=1.0), schema), + (["a", 1], schema), + (["a", None], schema), + (("a", 1), schema), + (MyObj(s="a", i=1), schema), + (MyObj(s="a", i=None), schema), + (MyObj(s="a"), schema), + ] + + # 
obj, data_type, exception class + failure_spec = [ + # String (match anything but None) + (None, StringType(), ValueError), + + # UDT + (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError), + + # Boolean + (1, BooleanType(), TypeError), + ("True", BooleanType(), TypeError), + ([1], BooleanType(), TypeError), + + # Byte + (-(2**7) - 1, ByteType(), ValueError), + (2**7, ByteType(), ValueError), + ("1", ByteType(), TypeError), + (1.0, ByteType(), TypeError), + + # Short + (-(2**15) - 1, ShortType(), ValueError), + (2**15, ShortType(), ValueError), + + # Integer + (-(2**31) - 1, IntegerType(), ValueError), + (2**31, IntegerType(), ValueError), + + # Float & Double + (1, FloatType(), TypeError), + (1, DoubleType(), TypeError), + + # Decimal + (1.0, DecimalType(), TypeError), + (1, DecimalType(), TypeError), + ("1.0", DecimalType(), TypeError), + + # Binary + (1, BinaryType(), TypeError), + + # Date/Timestamp + ("2000-01-02", DateType(), TypeError), + (946811040, TimestampType(), TypeError), + + # Array + (["1", None], ArrayType(StringType(), containsNull=False), ValueError), + ([1, "2"], ArrayType(IntegerType()), TypeError), + + # Map + ({"a": 1}, MapType(IntegerType(), IntegerType()), TypeError), + ({"a": "1"}, MapType(StringType(), IntegerType()), TypeError), + ({"a": None}, MapType(StringType(), IntegerType(), valueContainsNull=False), + ValueError), + + # Struct + ({"s": "a", "i": "1"}, schema, TypeError), + (Row(s="a"), schema, ValueError), # Row can't have missing field + (Row(s="a", i="1"), schema, TypeError), + (["a"], schema, ValueError), + (["a", "1"], schema, TypeError), + (MyObj(s="a", i="1"), schema, TypeError), + (MyObj(s=None, i="1"), schema, ValueError), + ] + + # Check success cases + for obj, data_type in success_spec: + try: + _make_type_verifier(data_type, nullable=False)(obj) + except Exception: + self.fail("verify_type(%s, %s, nullable=False)" % (obj, data_type)) + + # Check failure cases + for obj, data_type, exp in failure_spec: + msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp) + with self.assertRaises(exp, msg=msg): + _make_type_verifier(data_type, nullable=False)(obj) + + +@unittest.skipIf(not _have_arrow, "Arrow not installed") +class ArrowTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") + cls.schema = StructType([ + StructField("1_str_t", StringType(), True), + StructField("2_int_t", IntegerType(), True), + StructField("3_long_t", LongType(), True), + StructField("4_float_t", FloatType(), True), + StructField("5_double_t", DoubleType(), True)]) + cls.data = [("a", 1, 10, 0.2, 2.0), + ("b", 2, 20, 0.4, 4.0), + ("c", 3, 30, 0.8, 6.0)] + + def assertFramesEqual(self, df_with_arrow, df_without): + msg = ("DataFrame from Arrow is not equal" + + ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + + ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) + self.assertTrue(df_without.equals(df_with_arrow), msg=msg) + + def test_unsupported_datatype(self): + schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) + df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: df.toPandas()) + + def test_null_conversion(self): + df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] + + self.data) + pdf = df_null.toPandas() + null_counts = 
pdf.isnull().sum().tolist() + self.assertTrue(all([c == 1 for c in null_counts])) + + def test_toPandas_arrow_toggle(self): + df = self.spark.createDataFrame(self.data, schema=self.schema) + self.spark.conf.set("spark.sql.execution.arrow.enable", "false") + pdf = df.toPandas() + self.spark.conf.set("spark.sql.execution.arrow.enable", "true") + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_pandas_round_trip(self): + import pandas as pd + import numpy as np + data_dict = {} + for j, name in enumerate(self.schema.names): + data_dict[name] = [self.data[i][j] for i in range(len(self.data))] + # need to convert these to numpy types first + data_dict["2_int_t"] = np.int32(data_dict["2_int_t"]) + data_dict["4_float_t"] = np.float32(data_dict["4_float_t"]) + pdf = pd.DataFrame(data=data_dict) + df = self.spark.createDataFrame(self.data, schema=self.schema) + pdf_arrow = df.toPandas() + self.assertFramesEqual(pdf_arrow, pdf) + + def test_filtered_frame(self): + df = self.spark.range(3).toDF("i") + pdf = df.filter("i < 0").toPandas() + self.assertEqual(len(pdf.columns), 1) + self.assertEqual(pdf.columns[0], "i") + self.assertTrue(pdf.empty) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 26b54a7fb370..22fa273fc1aa 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -32,6 +32,7 @@ from py4j.protocol import register_input_converter from py4j.java_gateway import JavaClass +from pyspark import SparkContext from pyspark.serializers import CloudPickleSerializer __all__ = [ @@ -727,18 +728,6 @@ def __eq__(self, other): _BRACKETS = {'(': ')', '[': ']', '{': '}'} -def _parse_basic_datatype_string(s): - if s in _all_atomic_types.keys(): - return _all_atomic_types[s]() - elif s == "int": - return IntegerType() - elif _FIXED_DECIMAL.match(s): - m = _FIXED_DECIMAL.match(s) - return DecimalType(int(m.group(1)), int(m.group(2))) - else: - raise ValueError("Could not parse datatype: %s" % s) - - def _ignore_brackets_split(s, separator): """ Splits the given string by given separator, but ignore separators inside brackets pairs, e.g. @@ -771,32 +760,23 @@ def _ignore_brackets_split(s, separator): return parts -def _parse_struct_fields_string(s): - parts = _ignore_brackets_split(s, ",") - fields = [] - for part in parts: - name_and_type = _ignore_brackets_split(part, ":") - if len(name_and_type) != 2: - raise ValueError("The strcut field string format is: 'field_name:field_type', " + - "but got: %s" % part) - field_name = name_and_type[0].strip() - field_type = _parse_datatype_string(name_and_type[1]) - fields.append(StructField(field_name, field_type)) - return StructType(fields) - - def _parse_datatype_string(s): """ Parses the given data type string to a :class:`DataType`. The data type string format equals to :class:`DataType.simpleString`, except that top level struct type can omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name - for :class:`IntegerType`. + for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted + string and case-insensitive strings. 
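# Illustrative sketch of the user-visible effect of DDL support in
# _parse_datatype_string, assuming an active SparkSession `spark` and
# illustrative column names: schema strings passed to createDataFrame()
# are parsed through this function, so the DDL form works alongside the
# older "name: type" form.
df_ddl = spark.createDataFrame([(1, "a"), (2, "b")], "id INT, name STRING")
df_colon = spark.createDataFrame([(1, "a"), (2, "b")], "id: int, name: string")
assert df_ddl.schema == df_colon.schema  # both parse to the same StructType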
>>> _parse_datatype_string("int ") IntegerType + >>> _parse_datatype_string("INT ") + IntegerType >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ") StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true))) + >>> _parse_datatype_string("a DOUBLE, b STRING") + StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true))) >>> _parse_datatype_string("a: array< short>") StructType(List(StructField(a,ArrayType(ShortType,true),true))) >>> _parse_datatype_string(" map<string , string > ") @@ -806,43 +786,43 @@ def _parse_datatype_string(s): >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError:... + ParseException:... >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError:... + ParseException:... >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError:... + ParseException:... >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError:... + ParseException:... """ - s = s.strip() - if s.startswith("array<"): - if s[-1] != ">": - raise ValueError("'>' should be the last char, but got: %s" % s) - return ArrayType(_parse_datatype_string(s[6:-1])) - elif s.startswith("map<"): - if s[-1] != ">": - raise ValueError("'>' should be the last char, but got: %s" % s) - parts = _ignore_brackets_split(s[4:-1], ",") - if len(parts) != 2: - raise ValueError("The map type string format is: 'map<key_type,value_type>', " + - "but got: %s" % s) - kt = _parse_datatype_string(parts[0]) - vt = _parse_datatype_string(parts[1]) - return MapType(kt, vt) - elif s.startswith("struct<"): - if s[-1] != ">": - raise ValueError("'>' should be the last char, but got: %s" % s) - return _parse_struct_fields_string(s[7:-1]) - elif ":" in s: - return _parse_struct_fields_string(s) - else: - return _parse_basic_datatype_string(s) + sc = SparkContext._active_spark_context + + def from_ddl_schema(type_str): + return _parse_datatype_json_string( + sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json()) + + def from_ddl_datatype(type_str): + return _parse_datatype_json_string( + sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json()) + + try: + # DDL format, "fieldname datatype, fieldname datatype". + return from_ddl_schema(s) + except Exception as e: + try: + # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc. + return from_ddl_datatype(s) + except: + try: + # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case. + return from_ddl_datatype("struct<%s>" % s.strip()) + except: + raise e def _parse_datatype_json_string(json_string): @@ -1249,121 +1229,196 @@ def _infer_schema_type(obj, dataType): } -def _verify_type(obj, dataType, nullable=True): +def _make_type_verifier(dataType, nullable=True, name=None): """ - Verify the type of obj against dataType, raise a TypeError if they do not match. - - Also verify the value of obj against datatype, raise a ValueError if it's not within the allowed - range, e.g. using 128 as ByteType will overflow. Note that, Python float is not checked, so it - will become infinity when cast to Java float if it overflows.
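To make the fallback order in the new implementation above concrete, here is a small sketch of the three branches (DDL table schema, single data type, then the legacy struct-style string). The helper is internal and needs an active SparkContext because it now delegates to the JVM parsers; the sketch assumes this patch is applied:

from pyspark.sql import SparkSession
from pyspark.sql.types import _parse_datatype_string

spark = SparkSession.builder.master("local[1]").appName("ddl-parse-sketch").getOrCreate()

print(_parse_datatype_string("a INT, b STRING"))    # DDL schema, handled by StructType.fromDDL
print(_parse_datatype_string("INT"))                # single type name, handled by parseDataType
print(_parse_datatype_string("a: int, b: string"))  # legacy form, wrapped into struct<...> last

spark.stop()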
- - >>> _verify_type(None, StructType([])) - >>> _verify_type("", StringType()) - >>> _verify_type(0, LongType()) - >>> _verify_type(list(range(3)), ArrayType(ShortType())) - >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL + Make a verifier that checks the type of obj against dataType and raises a TypeError if they do + not match. + + This verifier also checks the value of obj against datatype and raises a ValueError if it's not + within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is + not checked, so it will become infinity when cast to Java float if it overflows. + + >>> _make_type_verifier(StructType([]))(None) + >>> _make_type_verifier(StringType())("") + >>> _make_type_verifier(LongType())(0) + >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3))) + >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... TypeError:... - >>> _verify_type({}, MapType(StringType(), IntegerType())) - >>> _verify_type((), StructType([])) - >>> _verify_type([], StructType([])) - >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(MapType(StringType(), IntegerType()))({}) + >>> _make_type_verifier(StructType([]))(()) + >>> _make_type_verifier(StructType([]))([]) + >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... >>> # Check if numeric values are within the allowed range. - >>> _verify_type(12, ByteType()) - >>> _verify_type(1234, ByteType()) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(ByteType())(12) + >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... - >>> _verify_type(None, ByteType(), False) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... - >>> _verify_type([1, None], ArrayType(ShortType(), False)) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier( + ... ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... - >>> _verify_type({None: 1}, MapType(StringType(), IntegerType())) + >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1}) Traceback (most recent call last): ... ValueError:... >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False) - >>> _verify_type((1, None), schema) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... ValueError:... 
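The doctests above exercise the new verifier factory one call at a time; the sketch below (again assuming this patch, and using what is an internal API) shows the intended pattern of building a verifier once per schema and reusing it per row, with the field name now carried into the error message:

from pyspark.sql.types import (StructType, StructField, IntegerType, StringType,
                               _make_type_verifier)

schema = StructType([
    StructField("a", IntegerType()),
    StructField("b", StringType(), nullable=False)])

verify = _make_type_verifier(schema)  # built once per schema
verify((1, "x"))                      # reused per row; passes silently

try:
    verify((1, None))                 # None in a non-nullable field
except ValueError as e:
    print(e)                          # the message should now name field b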
""" - if obj is None: - if nullable: - return - else: - raise ValueError("This field is not nullable, but got None") - # StringType can work with any types - if isinstance(dataType, StringType): - return + if name is None: + new_msg = lambda msg: msg + new_name = lambda n: "field %s" % n + else: + new_msg = lambda msg: "%s: %s" % (name, msg) + new_name = lambda n: "field %s in %s" % (n, name) - if isinstance(dataType, UserDefinedType): - if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): - raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.toInternal(obj), dataType.sqlType()) - return + def verify_nullability(obj): + if obj is None: + if nullable: + return True + else: + raise ValueError(new_msg("This field is not nullable, but got None")) + else: + return False _type = type(dataType) - assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) - if _type is StructType: - # check the type and fields later - pass - else: + def assert_acceptable_types(obj): + assert _type in _acceptable_types, \ + new_msg("unknown datatype: %s for object %r" % (dataType, obj)) + + def verify_acceptable_types(obj): # subclass of them can not be fromInternal in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) + raise TypeError(new_msg("%s can not accept object %r in type %s" + % (dataType, obj, type(obj)))) + + if isinstance(dataType, StringType): + # StringType can work with any types + verify_value = lambda _: _ + + elif isinstance(dataType, UserDefinedType): + verifier = _make_type_verifier(dataType.sqlType(), name=name) - if isinstance(dataType, ByteType): - if obj < -128 or obj > 127: - raise ValueError("object of ByteType out of range, got: %s" % obj) + def verify_udf(obj): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType))) + verifier(dataType.toInternal(obj)) + + verify_value = verify_udf + + elif isinstance(dataType, ByteType): + def verify_byte(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -128 or obj > 127: + raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj)) + + verify_value = verify_byte elif isinstance(dataType, ShortType): - if obj < -32768 or obj > 32767: - raise ValueError("object of ShortType out of range, got: %s" % obj) + def verify_short(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -32768 or obj > 32767: + raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj)) + + verify_value = verify_short elif isinstance(dataType, IntegerType): - if obj < -2147483648 or obj > 2147483647: - raise ValueError("object of IntegerType out of range, got: %s" % obj) + def verify_integer(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + if obj < -2147483648 or obj > 2147483647: + raise ValueError( + new_msg("object of IntegerType out of range, got: %s" % obj)) + + verify_value = verify_integer elif isinstance(dataType, ArrayType): - for i in obj: - _verify_type(i, dataType.elementType, dataType.containsNull) + element_verifier = _make_type_verifier( + dataType.elementType, dataType.containsNull, name="element in array %s" % name) + + def verify_array(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + for i in obj: + element_verifier(i) + + verify_value = verify_array elif isinstance(dataType, 
MapType): - for k, v in obj.items(): - _verify_type(k, dataType.keyType, False) - _verify_type(v, dataType.valueType, dataType.valueContainsNull) + key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name) + value_verifier = _make_type_verifier( + dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name) + + def verify_map(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + for k, v in obj.items(): + key_verifier(k) + value_verifier(v) + + verify_value = verify_map elif isinstance(dataType, StructType): - if isinstance(obj, dict): - for f in dataType.fields: - _verify_type(obj.get(f.name), f.dataType, f.nullable) - elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): - # the order in obj could be different than dataType.fields - for f in dataType.fields: - _verify_type(obj[f.name], f.dataType, f.nullable) - elif isinstance(obj, (tuple, list)): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with " - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType, f.nullable) - elif hasattr(obj, "__dict__"): - d = obj.__dict__ - for f in dataType.fields: - _verify_type(d.get(f.name), f.dataType, f.nullable) - else: - raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) + verifiers = [] + for f in dataType.fields: + verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name)) + verifiers.append((f.name, verifier)) + + def verify_struct(obj): + assert_acceptable_types(obj) + + if isinstance(obj, dict): + for f, verifier in verifiers: + verifier(obj.get(f)) + elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): + # the order in obj could be different than dataType.fields + for f, verifier in verifiers: + verifier(obj[f]) + elif isinstance(obj, (tuple, list)): + if len(obj) != len(verifiers): + raise ValueError( + new_msg("Length of object (%d) does not match with " + "length of fields (%d)" % (len(obj), len(verifiers)))) + for v, (_, verifier) in zip(obj, verifiers): + verifier(v) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + for f, verifier in verifiers: + verifier(d.get(f)) + else: + raise TypeError(new_msg("StructType can not accept object %r in type %s" + % (obj, type(obj)))) + verify_value = verify_struct + + else: + def verify_default(obj): + assert_acceptable_types(obj) + verify_acceptable_types(obj) + + verify_value = verify_default + + def verify(obj): + if not verify_nullability(obj): + verify_value(obj) + + return verify # This is used to unpickle a Row from JVM diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index ef16eaa30afa..11526824da8c 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1019,14 +1019,22 @@ def test_histogram(self): self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) self.assertRaises(TypeError, lambda: rdd.histogram(2)) - def test_repartitionAndSortWithinPartitions(self): + def test_repartitionAndSortWithinPartitions_asc(self): rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) - repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2) + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True) partitions = repartitioned.glom().collect() self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) + def 
test_repartitionAndSortWithinPartitions_desc(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False) + partitions = repartitioned.glom().collect() + self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)]) + self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)]) + def test_repartition_no_skewed(self): num_partitions = 20 a = self.sc.parallelize(range(int(1000)), 2) diff --git a/python/run-tests.py b/python/run-tests.py index b2e50435bb19..afd3d29a0ff9 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -54,7 +54,8 @@ def print_red(text): LOGGER = logging.getLogger() # Find out where the assembly jars are located. -for scala in ["2.11", "2.10"]: +# Later, add back 2.12 to this list: +for scala in ["2.11"]: build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) if os.path.isdir(build_dir): SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") diff --git a/python/setup.py b/python/setup.py index 2644d3e79dea..cfc83c68e3df 100644 --- a/python/setup.py +++ b/python/setup.py @@ -194,7 +194,7 @@ def _supports_symlinks(): 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', - install_requires=['py4j==0.10.4'], + install_requires=['py4j==0.10.6'], setup_requires=['pypandoc'], extras_require={ 'ml': ['numpy>=1.7'], diff --git a/repl/pom.xml b/repl/pom.xml index 6d133a3cfff7..51eb9b60dd54 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -32,8 +32,8 @@ repl - scala-2.10/src/main/scala - scala-2.10/src/test/scala + scala-2.11/src/main/scala + scala-2.11/src/test/scala @@ -71,7 +71,7 @@ ${scala.version} - ${jline.groupid} + jline jline @@ -170,23 +170,17 @@ + + + diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala deleted file mode 100644 index be9b79021d2a..000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.repl - -import scala.tools.nsc.{CompilerCommand, Settings} - -import org.apache.spark.annotation.DeveloperApi - -/** - * Command class enabling Spark-specific command line options (provided by - * org.apache.spark.repl.SparkRunnerSettings). 
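Stepping back to the repartitionAndSortWithinPartitions tests added above (just before the repl/ changes), here is a minimal sketch of the descending case they cover, assuming a running local SparkContext named sc:

pairs = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)

# The third argument is `ascending`; False sorts keys in descending order within each partition.
repartitioned = pairs.repartitionAndSortWithinPartitions(2, lambda key: key % 2, ascending=False)
print(repartitioned.glom().collect())
# one partition holds the even keys sorted descending, the other the odd keys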
- * - * @example new SparkCommandLine(Nil).settings - * - * @param args The list of command line arguments - * @param settings The underlying settings to associate with this set of - * command-line options - */ -@DeveloperApi -class SparkCommandLine(args: List[String], override val settings: Settings) - extends CompilerCommand(args, settings) { - def this(args: List[String], error: String => Unit) { - this(args, new SparkRunnerSettings(error)) - } - - def this(args: List[String]) { - // scalastyle:off println - this(args, str => Console.println("Error: " + str)) - // scalastyle:on println - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala deleted file mode 100644 index 2b5d56a89590..000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ /dev/null @@ -1,114 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Paul Phillips - */ - -package org.apache.spark.repl - -import scala.tools.nsc._ -import scala.tools.nsc.interpreter._ - -import scala.reflect.internal.util.BatchSourceFile -import scala.tools.nsc.ast.parser.Tokens.EOF - -import org.apache.spark.internal.Logging - -private[repl] trait SparkExprTyper extends Logging { - val repl: SparkIMain - - import repl._ - import global.{ reporter => _, Import => _, _ } - import definitions._ - import syntaxAnalyzer.{ UnitParser, UnitScanner, token2name } - import naming.freshInternalVarName - - object codeParser extends { val global: repl.global.type = repl.global } with CodeHandlers[Tree] { - def applyRule[T](code: String, rule: UnitParser => T): T = { - reporter.reset() - val scanner = newUnitParser(code) - val result = rule(scanner) - - if (!reporter.hasErrors) - scanner.accept(EOF) - - result - } - - def defns(code: String) = stmts(code) collect { case x: DefTree => x } - def expr(code: String) = applyRule(code, _.expr()) - def stmts(code: String) = applyRule(code, _.templateStats()) - def stmt(code: String) = stmts(code).last // guaranteed nonempty - } - - /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */ - def parse(line: String): Option[List[Tree]] = debugging(s"""parse("$line")""") { - var isIncomplete = false - reporter.withIncompleteHandler((_, _) => isIncomplete = true) { - val trees = codeParser.stmts(line) - if (reporter.hasErrors) { - Some(Nil) - } else if (isIncomplete) { - None - } else { - Some(trees) - } - } - } - // def parsesAsExpr(line: String) = { - // import codeParser._ - // (opt expr line).isDefined - // } - - def symbolOfLine(code: String): Symbol = { - def asExpr(): Symbol = { - val name = freshInternalVarName() - // Typing it with a lazy val would give us the right type, but runs - // into compiler bugs with things like existentials, so we compile it - // behind a def and strip the NullaryMethodType which wraps the expr. 
- val line = "def " + name + " = {\n" + code + "\n}" - - interpretSynthetic(line) match { - case IR.Success => - val sym0 = symbolOfTerm(name) - // drop NullaryMethodType - val sym = sym0.cloneSymbol setInfo afterTyper(sym0.info.finalResultType) - if (sym.info.typeSymbol eq UnitClass) NoSymbol else sym - case _ => NoSymbol - } - } - def asDefn(): Symbol = { - val old = repl.definedSymbolList.toSet - - interpretSynthetic(code) match { - case IR.Success => - repl.definedSymbolList filterNot old match { - case Nil => NoSymbol - case sym :: Nil => sym - case syms => NoSymbol.newOverloaded(NoPrefix, syms) - } - case _ => NoSymbol - } - } - beQuietDuring(asExpr()) orElse beQuietDuring(asDefn()) - } - - private var typeOfExpressionDepth = 0 - def typeOfExpression(expr: String, silent: Boolean = true): Type = { - if (typeOfExpressionDepth > 2) { - logDebug("Terminating typeOfExpression recursion for expression: " + expr) - return NoType - } - typeOfExpressionDepth += 1 - // Don't presently have a good way to suppress undesirable success output - // while letting errors through, so it is first trying it silently: if there - // is an error, and errors are desired, then it re-evaluates non-silently - // to induce the error message. - try beSilentDuring(symbolOfLine(expr).tpe) match { - case NoType if !silent => symbolOfLine(expr).tpe // generate error - case tpe => tpe - } - finally typeOfExpressionDepth -= 1 - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala deleted file mode 100644 index b7237a6ce822..000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ /dev/null @@ -1,1145 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Alexander Spoon - */ - -package org.apache.spark.repl - - -import java.net.URL - -import scala.reflect.io.AbstractFile -import scala.tools.nsc._ -import scala.tools.nsc.backend.JavaPlatform -import scala.tools.nsc.interpreter._ -import scala.tools.nsc.interpreter.{Results => IR} -import Predef.{println => _, _} -import java.io.{BufferedReader, FileReader} -import java.net.URI -import java.util.concurrent.locks.ReentrantLock -import scala.sys.process.Process -import scala.tools.nsc.interpreter.session._ -import scala.util.Properties.{jdkHome, javaVersion} -import scala.tools.util.{Javap} -import scala.annotation.tailrec -import scala.collection.mutable.ListBuffer -import scala.concurrent.ops -import scala.tools.nsc.util._ -import scala.tools.nsc.interpreter._ -import scala.tools.nsc.io.{File, Directory} -import scala.reflect.NameTransformer._ -import scala.tools.nsc.util.ScalaClassLoader._ -import scala.tools.util._ -import scala.language.{implicitConversions, existentials, postfixOps} -import scala.reflect.{ClassTag, classTag} -import scala.tools.reflect.StdRuntimeTags._ - -import java.lang.{Class => jClass} -import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} - -import org.apache.spark.SparkConf -import org.apache.spark.SparkContext -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.util.Utils - -/** The Scala interactive shell. It provides a read-eval-print loop - * around the Interpreter class. - * After instantiation, clients should call the main() method. 
- * - * If no in0 is specified, then input will come from the console, and - * the class will attempt to provide input editing feature such as - * input history. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - * @version 1.2 - */ -@DeveloperApi -class SparkILoop( - private val in0: Option[BufferedReader], - protected val out: JPrintWriter, - val master: Option[String] -) extends AnyRef with LoopCommands with SparkILoopInit with Logging { - def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master)) - def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None) - def this() = this(None, new JPrintWriter(Console.out, true), None) - - private var in: InteractiveReader = _ // the input stream from which commands come - - // NOTE: Exposed in package for testing - private[repl] var settings: Settings = _ - - private[repl] var intp: SparkIMain = _ - - @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp - @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i - - /** Having inherited the difficult "var-ness" of the repl instance, - * I'm trying to work around it by moving operations into a class from - * which it will appear a stable prefix. - */ - private def onIntp[T](f: SparkIMain => T): T = f(intp) - - class IMainOps[T <: SparkIMain](val intp: T) { - import intp._ - import global._ - - def printAfterTyper(msg: => String) = - intp.reporter printMessage afterTyper(msg) - - /** Strip NullaryMethodType artifacts. */ - private def replInfo(sym: Symbol) = { - sym.info match { - case NullaryMethodType(restpe) if sym.isAccessor => restpe - case info => info - } - } - def echoTypeStructure(sym: Symbol) = - printAfterTyper("" + deconstruct.show(replInfo(sym))) - - def echoTypeSignature(sym: Symbol, verbose: Boolean) = { - if (verbose) SparkILoop.this.echo("// Type signature") - printAfterTyper("" + replInfo(sym)) - - if (verbose) { - SparkILoop.this.echo("\n// Internal Type structure") - echoTypeStructure(sym) - } - } - } - implicit def stabilizeIMain(intp: SparkIMain) = new IMainOps[intp.type](intp) - - /** TODO - - * -n normalize - * -l label with case class parameter names - * -c complete - leave nothing out - */ - private def typeCommandInternal(expr: String, verbose: Boolean): Result = { - onIntp { intp => - val sym = intp.symbolOfLine(expr) - if (sym.exists) intp.echoTypeSignature(sym, verbose) - else "" - } - } - - // NOTE: Must be public for visibility - @DeveloperApi - var sparkContext: SparkContext = _ - - override def echoCommandMessage(msg: String) { - intp.reporter printMessage msg - } - - // def isAsync = !settings.Yreplsync.value - private[repl] def isAsync = false - // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) - private def history = in.history - - /** The context class loader at the time this object was created */ - protected val originalClassLoader = Utils.getContextOrSparkClassLoader - - // classpath entries added via :cp - private var addedClasspath: String = "" - - /** A reverse list of commands to replay if the user requests a :replay */ - private var replayCommandStack: List[String] = Nil - - /** A list of commands to replay if the user requests a :replay */ - private def replayCommands = replayCommandStack.reverse - - /** Record a command for replay should the user request a :replay */ - private def addReplay(cmd: String) = replayCommandStack ::= cmd - - private def savingReplayStack[T](body: => T): T 
= { - val saved = replayCommandStack - try body - finally replayCommandStack = saved - } - private def savingReader[T](body: => T): T = { - val saved = in - try body - finally in = saved - } - - - private def sparkCleanUp() { - echo("Stopping spark context.") - intp.beQuietDuring { - command("sc.stop()") - } - } - /** Close the interpreter and set the var to null. */ - private def closeInterpreter() { - if (intp ne null) { - sparkCleanUp() - intp.close() - intp = null - } - } - - class SparkILoopInterpreter extends SparkIMain(settings, out) { - outer => - - override private[repl] lazy val formatting = new Formatting { - def prompt = SparkILoop.this.prompt - } - override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader) - } - - /** - * Constructs a new interpreter. - */ - protected def createInterpreter() { - require(settings != null) - - if (addedClasspath != "") settings.classpath.append(addedClasspath) - val addedJars = - if (Utils.isWindows) { - // Strip any URI scheme prefix so we can add the correct path to the classpath - // e.g. file:/C:/my/path.jar -> C:/my/path.jar - getAddedJars().map { jar => new URI(jar).getPath.stripPrefix("/") } - } else { - // We need new URI(jar).getPath here for the case that `jar` includes encoded white space (%20). - getAddedJars().map { jar => new URI(jar).getPath } - } - // work around for Scala bug - val totalClassPath = addedJars.foldLeft( - settings.classpath.value)((l, r) => ClassPath.join(l, r)) - this.settings.classpath.value = totalClassPath - - intp = new SparkILoopInterpreter - } - - /** print a friendly help message */ - private def helpCommand(line: String): Result = { - if (line == "") helpSummary() - else uniqueCommand(line) match { - case Some(lc) => echo("\n" + lc.longHelp) - case _ => ambiguousError(line) - } - } - private def helpSummary() = { - val usageWidth = commands map (_.usageMsg.length) max - val formatStr = "%-" + usageWidth + "s %s %s" - - echo("All commands can be abbreviated, e.g. :he instead of :help.") - echo("Those marked with a * have more detailed help, e.g. :help imports.\n") - - commands foreach { cmd => - val star = if (cmd.hasLongHelp) "*" else " " - echo(formatStr.format(cmd.usageMsg, star, cmd.help)) - } - } - private def ambiguousError(cmd: String): Result = { - matchingCommands(cmd) match { - case Nil => echo(cmd + ": no such command. Type :help for help.") - case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") - } - Result(true, None) - } - private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) - private def uniqueCommand(cmd: String): Option[LoopCommand] = { - // this lets us add commands willy-nilly and only requires enough command to disambiguate - matchingCommands(cmd) match { - case List(x) => Some(x) - // exact match OK even if otherwise appears ambiguous - case xs => xs find (_.name == cmd) - } - } - private var fallbackMode = false - - private def toggleFallbackMode() { - val old = fallbackMode - fallbackMode = !old - System.setProperty("spark.repl.fallback", fallbackMode.toString) - echo(s""" - |Switched ${if (old) "off" else "on"} fallback mode without restarting. - | If you have defined classes in the repl, it would - |be good to redefine them incase you plan to use them. If you still run - |into issues it would be good to restart the repl and turn on `:fallback` - |mode as first command. 
- """.stripMargin) - } - - /** Show the history */ - private lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { - override def usage = "[num]" - def defaultLines = 20 - - def apply(line: String): Result = { - if (history eq NoHistory) - return "No history available." - - val xs = words(line) - val current = history.index - val count = try xs.head.toInt catch { case _: Exception => defaultLines } - val lines = history.asStrings takeRight count - val offset = current - lines.size + 1 - - for ((line, index) <- lines.zipWithIndex) - echo("%3d %s".format(index + offset, line)) - } - } - - // When you know you are most likely breaking into the middle - // of a line being typed. This softens the blow. - private[repl] def echoAndRefresh(msg: String) = { - echo("\n" + msg) - in.redrawLine() - } - private[repl] def echo(msg: String) = { - out println msg - out.flush() - } - private def echoNoNL(msg: String) = { - out print msg - out.flush() - } - - /** Search the history */ - private def searchHistory(_cmdline: String) { - val cmdline = _cmdline.toLowerCase - val offset = history.index - history.size + 1 - - for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) - echo("%d %s".format(index + offset, line)) - } - - private var currentPrompt = Properties.shellPromptString - - /** - * Sets the prompt string used by the REPL. - * - * @param prompt The new prompt string - */ - @DeveloperApi - def setPrompt(prompt: String) = currentPrompt = prompt - - /** - * Represents the current prompt string used by the REPL. - * - * @return The current prompt string - */ - @DeveloperApi - def prompt = currentPrompt - - import LoopCommand.{ cmd, nullary } - - /** Standard commands */ - private lazy val standardCommands = List( - cmd("cp", "", "add a jar or directory to the classpath", addClasspath), - cmd("help", "[command]", "print this summary or command-specific help", helpCommand), - historyCommand, - cmd("h?", "", "search the history", searchHistory), - cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), - cmd("implicits", "[-v]", "show the implicits in scope", implicitsCommand), - cmd("javap", "", "disassemble a file or class name", javapCommand), - cmd("load", "", "load and interpret a Scala file", loadCommand), - nullary("paste", "enter paste mode: all input up to ctrl-D compiled together", pasteCommand), -// nullary("power", "enable power user mode", powerCmd), - nullary("quit", "exit the repl", () => Result(false, None)), - nullary("replay", "reset execution and replay all previous commands", replay), - nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), - shCommand, - nullary("silent", "disable/enable automatic printing of results", verbosity), - nullary("fallback", """ - |disable/enable advanced repl changes, these fix some issues but may introduce others. 
- |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode), - cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), - nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) - ) - - /** Power user commands */ - private lazy val powerCommands: List[LoopCommand] = List( - // cmd("phase", "", "set the implicit phase for power commands", phaseCommand) - ) - - // private def dumpCommand(): Result = { - // echo("" + power) - // history.asStrings takeRight 30 foreach echo - // in.redrawLine() - // } - // private def valsCommand(): Result = power.valsDescription - - private val typeTransforms = List( - "scala.collection.immutable." -> "immutable.", - "scala.collection.mutable." -> "mutable.", - "scala.collection.generic." -> "generic.", - "java.lang." -> "jl.", - "scala.runtime." -> "runtime." - ) - - private def importsCommand(line: String): Result = { - val tokens = words(line) - val handlers = intp.languageWildcardHandlers ++ intp.importHandlers - val isVerbose = tokens contains "-v" - - handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { - case (handler, idx) => - val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) - val imps = handler.implicitSymbols - val found = tokens filter (handler importsSymbolNamed _) - val typeMsg = if (types.isEmpty) "" else types.size + " types" - val termMsg = if (terms.isEmpty) "" else terms.size + " terms" - val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" - val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") - val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") - - intp.reporter.printMessage("%2d) %-30s %s%s".format( - idx + 1, - handler.importString, - statsMsg, - foundMsg - )) - } - } - - private def implicitsCommand(line: String): Result = onIntp { intp => - import intp._ - import global._ - - def p(x: Any) = intp.reporter.printMessage("" + x) - - // If an argument is given, only show a source with that - // in its name somewhere. - val args = line split "\\s+" - val filtered = intp.implicitSymbolsBySource filter { - case (source, syms) => - (args contains "-v") || { - if (line == "") (source.fullName.toString != "scala.Predef") - else (args exists (source.name.toString contains _)) - } - } - - if (filtered.isEmpty) - return "No implicits have been imported other than those in Predef." - - filtered foreach { - case (source, syms) => - p("/* " + syms.size + " implicit members imported from " + source.fullName + " */") - - // This groups the members by where the symbol is defined - val byOwner = syms groupBy (_.owner) - val sortedOwners = byOwner.toList sortBy { case (owner, _) => afterTyper(source.info.baseClasses indexOf owner) } - - sortedOwners foreach { - case (owner, members) => - // Within each owner, we cluster results based on the final result type - // if there are more than a couple, and sort each cluster based on name. - // This is really just trying to make the 100 or so implicits imported - // by default into something readable. 
- val memberGroups: List[List[Symbol]] = { - val groups = members groupBy (_.tpe.finalResultType) toList - val (big, small) = groups partition (_._2.size > 3) - val xss = ( - (big sortBy (_._1.toString) map (_._2)) :+ - (small flatMap (_._2)) - ) - - xss map (xs => xs sortBy (_.name.toString)) - } - - val ownerMessage = if (owner == source) " defined in " else " inherited from " - p(" /* " + members.size + ownerMessage + owner.fullName + " */") - - memberGroups foreach { group => - group foreach (s => p(" " + intp.symbolDefString(s))) - p("") - } - } - p("") - } - } - - private def findToolsJar() = { - val jdkPath = Directory(jdkHome) - val jar = jdkPath / "lib" / "tools.jar" toFile; - - if (jar isFile) - Some(jar) - else if (jdkPath.isDirectory) - jdkPath.deepFiles find (_.name == "tools.jar") - else None - } - private def addToolsJarToLoader() = { - val cl = findToolsJar match { - case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) - case _ => intp.classLoader - } - if (Javap.isAvailable(cl)) { - logDebug(":javap available.") - cl - } - else { - logDebug(":javap unavailable: no tools.jar at " + jdkHome) - intp.classLoader - } - } - - private def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) { - override def tryClass(path: String): Array[Byte] = { - val hd :: rest = path split '.' toList; - // If there are dots in the name, the first segment is the - // key to finding it. - if (rest.nonEmpty) { - intp optFlatName hd match { - case Some(flat) => - val clazz = flat :: rest mkString NAME_JOIN_STRING - val bytes = super.tryClass(clazz) - if (bytes.nonEmpty) bytes - else super.tryClass(clazz + MODULE_SUFFIX_STRING) - case _ => super.tryClass(path) - } - } - else { - // Look for Foo first, then Foo$, but if Foo$ is given explicitly, - // we have to drop the $ to find object Foo, then tack it back onto - // the end of the flattened name. - def className = intp flatName path - def moduleName = (intp flatName path.stripSuffix(MODULE_SUFFIX_STRING)) + MODULE_SUFFIX_STRING - - val bytes = super.tryClass(className) - if (bytes.nonEmpty) bytes - else super.tryClass(moduleName) - } - } - } - // private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) - private lazy val javap = - try newJavap() - catch { case _: Exception => null } - - // Still todo: modules. - private def typeCommand(line0: String): Result = { - line0.trim match { - case "" => ":type [-v] " - case s if s startsWith "-v " => typeCommandInternal(s stripPrefix "-v " trim, true) - case s => typeCommandInternal(s, false) - } - } - - private def warningsCommand(): Result = { - if (intp.lastWarnings.isEmpty) - "Can't find any cached warnings." - else - intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } - } - - private def javapCommand(line: String): Result = { - if (javap == null) - ":javap unavailable, no tools.jar at %s. 
Set JDK_HOME.".format(jdkHome) - else if (javaVersion startsWith "1.7") - ":javap not yet working with java 1.7" - else if (line == "") - ":javap [-lcsvp] [path1 path2 ...]" - else - javap(words(line)) foreach { res => - if (res.isError) return "Failed: " + res.value - else res.show() - } - } - - private def wrapCommand(line: String): Result = { - def failMsg = "Argument to :wrap must be the name of a method with signature [T](=> T): T" - onIntp { intp => - import intp._ - import global._ - - words(line) match { - case Nil => - intp.executionWrapper match { - case "" => "No execution wrapper is set." - case s => "Current execution wrapper: " + s - } - case "clear" :: Nil => - intp.executionWrapper match { - case "" => "No execution wrapper is set." - case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper." - } - case wrapper :: Nil => - intp.typeOfExpression(wrapper) match { - case PolyType(List(targ), MethodType(List(arg), restpe)) => - intp setExecutionWrapper intp.pathToTerm(wrapper) - "Set wrapper to '" + wrapper + "'" - case tp => - failMsg + "\nFound: " - } - case _ => failMsg - } - } - } - - private def pathToPhaseWrapper = intp.pathToTerm("$r") + ".phased.atCurrent" - // private def phaseCommand(name: String): Result = { - // val phased: Phased = power.phased - // import phased.NoPhaseName - - // if (name == "clear") { - // phased.set(NoPhaseName) - // intp.clearExecutionWrapper() - // "Cleared active phase." - // } - // else if (name == "") phased.get match { - // case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" - // case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) - // } - // else { - // val what = phased.parse(name) - // if (what.isEmpty || !phased.set(what)) - // "'" + name + "' does not appear to represent a valid phase." - // else { - // intp.setExecutionWrapper(pathToPhaseWrapper) - // val activeMessage = - // if (what.toString.length == name.length) "" + what - // else "%s (%s)".format(what, name) - - // "Active phase is now: " + activeMessage - // } - // } - // } - - /** - * Provides a list of available commands. - * - * @return The list of commands - */ - @DeveloperApi - def commands: List[LoopCommand] = standardCommands /*++ ( - if (isReplPower) powerCommands else Nil - )*/ - - private val replayQuestionMessage = - """|That entry seems to have slain the compiler. Shall I replay - |your session? I can re-run each line except the last one. - |[y/n] - """.trim.stripMargin - - private def crashRecovery(ex: Throwable): Boolean = { - echo(ex.toString) - ex match { - case _: NoSuchMethodError | _: NoClassDefFoundError => - echo("\nUnrecoverable error.") - throw ex - case _ => - def fn(): Boolean = - try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) - catch { case _: RuntimeException => false } - - if (fn()) replay() - else echo("\nAbandoning crashed session.") - } - true - } - - /** The main read-eval-print loop for the repl. It calls - * command() for each line of input, and stops when - * command() returns false. 
- */ - private def loop() { - def readOneLine() = { - out.flush() - in readLine prompt - } - // return false if repl should exit - def processLine(line: String): Boolean = { - if (isAsync) { - if (!awaitInitialized()) return false - runThunks() - } - if (line eq null) false // assume null means EOF - else command(line) match { - case Result(false, _) => false - case Result(_, Some(finalLine)) => addReplay(finalLine) ; true - case _ => true - } - } - def innerLoop() { - val shouldContinue = try { - processLine(readOneLine()) - } catch {case t: Throwable => crashRecovery(t)} - if (shouldContinue) - innerLoop() - } - innerLoop() - } - - /** interpret all lines from a specified file */ - private def interpretAllFrom(file: File) { - savingReader { - savingReplayStack { - file applyReader { reader => - in = SimpleReader(reader, out, false) - echo("Loading " + file + "...") - loop() - } - } - } - } - - /** create a new interpreter and replay the given commands */ - private def replay() { - reset() - if (replayCommandStack.isEmpty) - echo("Nothing to replay.") - else for (cmd <- replayCommands) { - echo("Replaying: " + cmd) // flush because maybe cmd will have its own output - command(cmd) - echo("") - } - } - private def resetCommand() { - echo("Resetting repl state.") - if (replayCommandStack.nonEmpty) { - echo("Forgetting this session history:\n") - replayCommands foreach echo - echo("") - replayCommandStack = Nil - } - if (intp.namedDefinedTerms.nonEmpty) - echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) - if (intp.definedTypes.nonEmpty) - echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) - - reset() - } - - private def reset() { - intp.reset() - // unleashAndSetPhase() - } - - /** fork a shell and run a command */ - private lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { - override def usage = "" - def apply(line: String): Result = line match { - case "" => showUsage() - case _ => - val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")" - intp interpret toRun - () - } - } - - private def withFile(filename: String)(action: File => Unit) { - val f = File(filename) - - if (f.exists) action(f) - else echo("That file does not exist") - } - - private def loadCommand(arg: String) = { - var shouldReplay: Option[String] = None - withFile(arg)(f => { - interpretAllFrom(f) - shouldReplay = Some(":load " + arg) - }) - Result(true, shouldReplay) - } - - private def addAllClasspath(args: Seq[String]): Unit = { - var added = false - var totalClasspath = "" - for (arg <- args) { - val f = File(arg).normalize - if (f.exists) { - added = true - addedClasspath = ClassPath.join(addedClasspath, f.path) - totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) - intp.addUrlsToClassPath(f.toURI.toURL) - sparkContext.addJar(f.toURI.toURL.getPath) - } - } - } - - private def addClasspath(arg: String): Unit = { - val f = File(arg).normalize - if (f.exists) { - addedClasspath = ClassPath.join(addedClasspath, f.path) - intp.addUrlsToClassPath(f.toURI.toURL) - sparkContext.addJar(f.toURI.toURL.getPath) - echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, intp.global.classPath.asClasspathString)) - } - else echo("The path '" + f + "' doesn't seem to exist.") - } - - - private def powerCmd(): Result = { - if (isReplPower) "Already in power mode." 
- else enablePowerMode(false) - } - - private[repl] def enablePowerMode(isDuringInit: Boolean) = { - // replProps.power setValue true - // unleashAndSetPhase() - // asyncEcho(isDuringInit, power.banner) - } - // private def unleashAndSetPhase() { -// if (isReplPower) { -// // power.unleash() -// // Set the phase to "typer" -// intp beSilentDuring phaseCommand("typer") -// } -// } - - private def asyncEcho(async: Boolean, msg: => String) { - if (async) asyncMessage(msg) - else echo(msg) - } - - private def verbosity() = { - // val old = intp.printResults - // intp.printResults = !old - // echo("Switched " + (if (old) "off" else "on") + " result printing.") - } - - /** - * Run one command submitted by the user. Two values are returned: - * (1) whether to keep running, (2) the line to record for replay, - * if any. - */ - private[repl] def command(line: String): Result = { - if (line startsWith ":") { - val cmd = line.tail takeWhile (x => !x.isWhitespace) - uniqueCommand(cmd) match { - case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) - case _ => ambiguousError(cmd) - } - } - else if (intp.global == null) Result(false, None) // Notice failure to create compiler - else Result(true, interpretStartingWith(line)) - } - - private def readWhile(cond: String => Boolean) = { - Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) - } - - private def pasteCommand(): Result = { - echo("// Entering paste mode (ctrl-D to finish)\n") - val code = readWhile(_ => true) mkString "\n" - echo("\n// Exiting paste mode, now interpreting.\n") - intp interpret code - () - } - - private object paste extends Pasted { - val ContinueString = " | " - val PromptString = "scala> " - - def interpret(line: String): Unit = { - echo(line.trim) - intp interpret line - echo("") - } - - def transcript(start: String) = { - echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") - apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) - } - } - import paste.{ ContinueString, PromptString } - - /** - * Interpret expressions starting with the first line. - * Read lines until a complete compilation unit is available - * or until a syntax error has been seen. If a full unit is - * read, go ahead and interpret it. Return the full string - * to be recorded for replay, if any. - */ - private def interpretStartingWith(code: String): Option[String] = { - // signal completion non-completion input has been received - in.completion.resetVerbosity() - - def reallyInterpret = { - val reallyResult = intp.interpret(code) - (reallyResult, reallyResult match { - case IR.Error => None - case IR.Success => Some(code) - case IR.Incomplete => - if (in.interactive && code.endsWith("\n\n")) { - echo("You typed two blank lines. Starting a new command.") - None - } - else in.readLine(ContinueString) match { - case null => - // we know compilation is going to fail since we're at EOF and the - // parser thinks the input is still incomplete, but since this is - // a file being read non-interactively we want to fail. So we send - // it straight to the compiler for the nice error message. - intp.compileString(code) - None - - case line => interpretStartingWith(code + "\n" + line) - } - }) - } - - /** Here we place ourselves between the user and the interpreter and examine - * the input they are ostensibly submitting. We intervene in several cases: - * - * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. - * 2) If the line starts with "." (but not ".." 
or "./") it is treated as an invocation - * on the previous result. - * 3) If the Completion object's execute returns Some(_), we inject that value - * and avoid the interpreter, as it's likely not valid scala code. - */ - if (code == "") None - else if (!paste.running && code.trim.startsWith(PromptString)) { - paste.transcript(code) - None - } - else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { - interpretStartingWith(intp.mostRecentVar + code) - } - else if (code.trim startsWith "//") { - // line comment, do nothing - None - } - else - reallyInterpret._2 - } - - // runs :load `file` on any files passed via -i - private def loadFiles(settings: Settings) = settings match { - case settings: SparkRunnerSettings => - for (filename <- settings.loadfiles.value) { - val cmd = ":load " + filename - command(cmd) - addReplay(cmd) - echo("") - } - case _ => - } - - /** Tries to create a JLineReader, falling back to SimpleReader: - * unless settings or properties are such that it should start - * with SimpleReader. - */ - private def chooseReader(settings: Settings): InteractiveReader = { - if (settings.Xnojline.value || Properties.isEmacsShell) - SimpleReader() - else try new SparkJLineReader( - if (settings.noCompletion.value) NoCompletion - else new SparkJLineCompletion(intp) - ) - catch { - case ex @ (_: Exception | _: NoClassDefFoundError) => - echo("Failed to created SparkJLineReader: " + ex + "\nFalling back to SimpleReader.") - SimpleReader() - } - } - - private val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - private val m = u.runtimeMirror(Utils.getSparkClassLoader) - private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = - u.TypeTag[T]( - m, - new TypeCreator { - def apply[U <: ApiUniverse with Singleton](m: Mirror[U]): U # Type = - m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] - }) - - private def process(settings: Settings): Boolean = savingContextLoader { - this.settings = settings - createInterpreter() - - // sets in to some kind of reader depending on environmental cues - in = in0 match { - case Some(reader) => SimpleReader(reader, out, true) - case None => - // some post-initialization - chooseReader(settings) match { - case x: SparkJLineReader => addThunk(x.consoleReader.postInit) ; x - case x => x - } - } - lazy val tagOfSparkIMain = tagOfStaticClass[org.apache.spark.repl.SparkIMain] - // Bind intp somewhere out of the regular namespace where - // we can get at it in generated code. - addThunk(intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfSparkIMain, classTag[SparkIMain]))) - addThunk({ - import scala.tools.nsc.io._ - import Properties.userHome - import scala.compat.Platform.EOL - val autorun = replProps.replAutorunCode.option flatMap (f => io.File(f).safeSlurp()) - if (autorun.isDefined) intp.quietRun(autorun.get) - }) - - addThunk(printWelcome()) - addThunk(initializeSpark()) - - // it is broken on startup; go ahead and exit - if (intp.reporter.hasErrors) - return false - - // This is about the illusion of snappiness. We call initialize() - // which spins off a separate thread, then print the prompt and try - // our best to look ready. The interlocking lazy vals tend to - // inter-deadlock, so we break the cycle with a single asynchronous - // message to an rpcEndpoint. 
- if (isAsync) { - intp initialize initializedCallback() - createAsyncListener() // listens for signal to run postInitialization - } - else { - intp.initializeSynchronous() - postInitialization() - } - // printWelcome() - - loadFiles(settings) - - try loop() - catch AbstractOrMissingHandler() - finally closeInterpreter() - - true - } - - // NOTE: Must be public for visibility - @DeveloperApi - def createSparkSession(): SparkSession = { - val execUri = System.getenv("SPARK_EXECUTOR_URI") - val jars = getAddedJars() - val conf = new SparkConf() - .setMaster(getMaster()) - .setJars(jars) - .setIfMissing("spark.app.name", "Spark shell") - // SparkContext will detect this configuration and register it with the RpcEnv's - // file server, setting spark.repl.class.uri to the actual URI for executors to - // use. This is sort of ugly but since executors are started as part of SparkContext - // initialization in certain cases, there's an initialization order issue that prevents - // this from being set after SparkContext is instantiated. - .set("spark.repl.class.outputDir", intp.outputDir.getAbsolutePath()) - if (execUri != null) { - conf.set("spark.executor.uri", execUri) - } - - val builder = SparkSession.builder.config(conf) - val sparkSession = if (SparkSession.hiveClassesArePresent) { - logInfo("Creating Spark session with Hive support") - builder.enableHiveSupport().getOrCreate() - } else { - logInfo("Creating Spark session") - builder.getOrCreate() - } - sparkContext = sparkSession.sparkContext - sparkSession - } - - private def getMaster(): String = { - val master = this.master match { - case Some(m) => m - case None => - val envMaster = sys.env.get("MASTER") - val propMaster = sys.props.get("spark.master") - propMaster.orElse(envMaster).getOrElse("local[*]") - } - master - } - - /** process command-line arguments and do as they request */ - def process(args: Array[String]): Boolean = { - val command = new SparkCommandLine(args.toList, msg => echo(msg)) - def neededHelp(): String = - (if (command.settings.help.value) command.usageMsg + "\n" else "") + - (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "") - - // if they asked for no help and command is valid, we call the real main - neededHelp() match { - case "" => command.ok && process(command.settings) - case help => echoNoNL(help) ; true - } - } - - @deprecated("Use `process` instead", "2.9.0") - private def main(settings: Settings): Unit = process(settings) - - @DeveloperApi - def getAddedJars(): Array[String] = { - val conf = new SparkConf().setMaster(getMaster()) - val envJars = sys.env.get("ADD_JARS") - if (envJars.isDefined) { - logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") - } - val jars = { - val userJars = Utils.getUserJars(conf, isShell = true) - if (userJars.isEmpty) { - envJars.getOrElse("") - } else { - userJars.mkString(",") - } - } - Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) - } - -} - -object SparkILoop extends Logging { - implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp - private def echo(msg: String) = Console println msg - - // Designed primarily for use by test code: take a String with a - // bunch of code, and prints out a transcript of what it would look - // like if you'd just typed it into the repl. 
- private[repl] def runForTranscript(code: String, settings: Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) { - override def write(str: String) = { - // completely skip continuation lines - if (str forall (ch => ch.isWhitespace || ch == '|')) () - // print a newline on empty scala prompts - else if ((str contains '\n') && (str.trim == "scala> ")) super.write("\n") - else super.write(str) - } - } - val input = new BufferedReader(new StringReader(code)) { - override def readLine(): String = { - val s = super.readLine() - // helping out by printing the line being interpreted. - if (s != null) - // scalastyle:off println - output.println(s) - // scalastyle:on println - s - } - } - val repl = new SparkILoop(input, output) - - if (settings.classpath.isDefault) - settings.classpath.value = sys.props("java.class.path") - - repl.getAddedJars().map(jar => new URI(jar).getPath).foreach(settings.classpath.append(_)) - - repl process settings - } - } - } - - /** Creates an interpreter loop with default settings and feeds - * the given code to it as input. - */ - private[repl] def run(code: String, sets: Settings = new Settings): String = { - import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - - stringFromStream { ostream => - Console.withOut(ostream) { - val input = new BufferedReader(new StringReader(code)) - val output = new JPrintWriter(new OutputStreamWriter(ostream), true) - val repl = new ILoop(input, output) - - if (sets.classpath.isDefault) - sets.classpath.value = sys.props("java.class.path") - - repl process sets - } - } - } - private[repl] def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala deleted file mode 100644 index 5f0d92bccd80..000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ /dev/null @@ -1,168 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Paul Phillips - */ - -package org.apache.spark.repl - -import scala.tools.nsc._ -import scala.tools.nsc.interpreter._ - -import scala.tools.nsc.util.stackTraceString - -import org.apache.spark.SPARK_VERSION - -/** - * Machinery for the asynchronous initialization of the repl. 
- */ -private[repl] trait SparkILoopInit { - self: SparkILoop => - - /** Print a welcome message */ - def printWelcome() { - echo("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version %s - /_/ -""".format(SPARK_VERSION)) - import Properties._ - val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) - echo(welcomeMsg) - echo("Type in expressions to have them evaluated.") - echo("Type :help for more information.") - } - - protected def asyncMessage(msg: String) { - if (isReplInfo || isReplPower) - echoAndRefresh(msg) - } - - private val initLock = new java.util.concurrent.locks.ReentrantLock() - private val initCompilerCondition = initLock.newCondition() // signal the compiler is initialized - private val initLoopCondition = initLock.newCondition() // signal the whole repl is initialized - private val initStart = System.nanoTime - - private def withLock[T](body: => T): T = { - initLock.lock() - try body - finally initLock.unlock() - } - // a condition used to ensure serial access to the compiler. - @volatile private var initIsComplete = false - @volatile private var initError: String = null - private def elapsed() = "%.3f".format((System.nanoTime - initStart).toDouble / 1000000000L) - - // the method to be called when the interpreter is initialized. - // Very important this method does nothing synchronous (i.e. do - // not try to use the interpreter) because until it returns, the - // repl's lazy val `global` is still locked. - protected def initializedCallback() = withLock(initCompilerCondition.signal()) - - // Spins off a thread which awaits a single message once the interpreter - // has been initialized. - protected def createAsyncListener() = { - io.spawn { - withLock(initCompilerCondition.await()) - asyncMessage("[info] compiler init time: " + elapsed() + " s.") - postInitialization() - } - } - - // called from main repl loop - protected def awaitInitialized(): Boolean = { - if (!initIsComplete) - withLock { while (!initIsComplete) initLoopCondition.await() } - if (initError != null) { - // scalastyle:off println - println(""" - |Failed to initialize the REPL due to an unexpected error. - |This is a bug, please, report it along with the error diagnostics printed below. 
- |%s.""".stripMargin.format(initError) - ) - // scalastyle:on println - false - } else true - } - // private def warningsThunks = List( - // () => intp.bind("lastWarnings", "" + typeTag[List[(Position, String)]], intp.lastWarnings _), - // ) - - protected def postInitThunks = List[Option[() => Unit]]( - Some(intp.setContextClassLoader _), - if (isReplPower) Some(() => enablePowerMode(true)) else None - ).flatten - // ++ ( - // warningsThunks - // ) - // called once after init condition is signalled - protected def postInitialization() { - try { - postInitThunks foreach (f => addThunk(f())) - runThunks() - } catch { - case ex: Throwable => - initError = stackTraceString(ex) - throw ex - } finally { - initIsComplete = true - - if (isAsync) { - asyncMessage("[info] total init time: " + elapsed() + " s.") - withLock(initLoopCondition.signal()) - } - } - } - - def initializeSpark() { - intp.beQuietDuring { - command(""" - @transient val spark = org.apache.spark.repl.Main.interp.createSparkSession() - @transient val sc = { - val _sc = spark.sparkContext - if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { - val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) - if (proxyUrl != null) { - println(s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") - } else { - println(s"Spark Context Web UI is available at Spark Master Public URL") - } - } else { - _sc.uiWebUrl.foreach { - webUrl => println(s"Spark context Web UI available at ${webUrl}") - } - } - println("Spark context available as 'sc' " + - s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - println("Spark session available as 'spark'.") - _sc - } - """) - command("import org.apache.spark.SparkContext._") - command("import spark.implicits._") - command("import spark.sql") - command("import org.apache.spark.sql.functions._") - } - } - - // code to be executed only after the interpreter is initialized - // and the lazy val `global` can be accessed without risk of deadlock. 
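For context, the quoted command block inside the removed initializeSpark() above effectively runs the following in the user's shell session once the interpreter is ready. This is a simplified restatement of that block, not new behaviour; the Web UI / reverse-proxy reporting and the master / app id banner are omitted here.

    // What the shell preamble binds and imports for the user (simplified).
    @transient val spark = org.apache.spark.repl.Main.interp.createSparkSession()
    @transient val sc = spark.sparkContext
    import org.apache.spark.SparkContext._
    import spark.implicits._
    import spark.sql
    import org.apache.spark.sql.functions._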
- private var pendingThunks: List[() => Unit] = Nil - protected def addThunk(body: => Unit) = synchronized { - pendingThunks :+= (() => body) - } - protected def runThunks(): Unit = synchronized { - if (pendingThunks.nonEmpty) - logDebug("Clearing " + pendingThunks.size + " thunks.") - - while (pendingThunks.nonEmpty) { - val thunk = pendingThunks.head - pendingThunks = pendingThunks.tail - thunk() - } - } -} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala deleted file mode 100644 index 74a04d5a42bb..000000000000 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ /dev/null @@ -1,1808 +0,0 @@ -// scalastyle:off - -/* NSC -- new Scala compiler - * Copyright 2005-2013 LAMP/EPFL - * @author Martin Odersky - */ - -package org.apache.spark.repl - -import java.io.File - -import scala.tools.nsc._ -import scala.tools.nsc.backend.JavaPlatform -import scala.tools.nsc.interpreter._ - -import Predef.{ println => _, _ } -import scala.tools.nsc.util.{MergedClassPath, stringFromWriter, ScalaClassLoader, stackTraceString} -import scala.reflect.internal.util._ -import java.net.URL -import scala.sys.BooleanProp -import io.{AbstractFile, PlainFile, VirtualDirectory} - -import reporters._ -import symtab.Flags -import scala.reflect.internal.Names -import scala.tools.util.PathResolver -import ScalaClassLoader.URLClassLoader -import scala.tools.nsc.util.Exceptional.unwrap -import scala.collection.{ mutable, immutable } -import scala.util.control.Exception.{ ultimately } -import SparkIMain._ -import java.util.concurrent.Future -import typechecker.Analyzer -import scala.language.implicitConversions -import scala.reflect.runtime.{ universe => ru } -import scala.reflect.{ ClassTag, classTag } -import scala.tools.reflect.StdRuntimeTags._ -import scala.util.control.ControlThrowable - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils -import org.apache.spark.annotation.DeveloperApi - -// /** directory to save .class files to */ -// private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) { -// private def pp(root: AbstractFile, indentLevel: Int) { -// val spaces = " " * indentLevel -// out.println(spaces + root.name) -// if (root.isDirectory) -// root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1)) -// } -// // print the contents hierarchically -// def show() = pp(this, 0) -// } - - /** An interpreter for Scala code. - * - * The main public entry points are compile(), interpret(), and bind(). - * The compile() method loads a complete Scala file. The interpret() method - * executes one line of Scala code at the request of the user. The bind() - * method binds an object to a variable that can then be used by later - * interpreted code. - * - * The overall approach is based on compiling the requested code and then - * using a Java classloader and Java reflection to run the code - * and access its results. - * - * In more detail, a single compiler instance is used - * to accumulate all successfully compiled or interpreted Scala code. To - * "interpret" a line of code, the compiler generates a fresh object that - * includes the line of code and which has public member(s) to export - * all variables defined by that code. 
To extract the result of an - * interpreted line to show the user, a second "result object" is created - * which imports the variables exported by the above object and then - * exports members called "$eval" and "$print". To accommodate user expressions - * that read from variables or methods defined in previous statements, "import" - * statements are used. - * - * This interpreter shares the strengths and weaknesses of using the - * full compiler-to-Java. The main strength is that interpreted code - * behaves exactly as does compiled code, including running at full speed. - * The main weakness is that redefining classes and methods is not handled - * properly, because rebinding at the Java level is technically difficult. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - */ - @DeveloperApi - class SparkIMain( - initialSettings: Settings, - val out: JPrintWriter, - propagateExceptions: Boolean = false) - extends SparkImports with Logging { imain => - - private val conf = new SparkConf() - - private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") - /** Local directory to save .class files too */ - private[repl] val outputDir = { - val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf)) - Utils.createTempDir(root = rootDir, namePrefix = "repl") - } - if (SPARK_DEBUG_REPL) { - echo("Output directory: " + outputDir) - } - - /** - * Returns the path to the output directory containing all generated - * class files that will be served by the REPL class server. - */ - @DeveloperApi - lazy val getClassOutputDirectory = outputDir - - private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles - /** Jetty server that will serve our classes to worker nodes */ - private var currentSettings: Settings = initialSettings - private var printResults = true // whether to print result lines - private var totalSilence = false // whether to print anything - private var _initializeComplete = false // compiler is initialized - private var _isInitialized: Future[Boolean] = null // set up initialization future - private var bindExceptions = true // whether to bind the lastException variable - private var _executionWrapper = "" // code to be wrapped around all lines - - /** We're going to go to some trouble to initialize the compiler asynchronously. - * It's critical that nothing call into it until it's been initialized or we will - * run into unrecoverable issues, but the perceived repl startup time goes - * through the roof if we wait for it. So we initialize it with a future and - * use a lazy val to ensure that any attempt to use the compiler object waits - * on the future. 
- */ - private var _classLoader: AbstractFileClassLoader = null // active classloader - private val _compiler: Global = newCompiler(settings, reporter) // our private compiler - - private trait ExposeAddUrl extends URLClassLoader { def addNewUrl(url: URL) = this.addURL(url) } - private var _runtimeClassLoader: URLClassLoader with ExposeAddUrl = null // wrapper exposing addURL - - private val nextReqId = { - var counter = 0 - () => { counter += 1 ; counter } - } - - private def compilerClasspath: Seq[URL] = ( - if (isInitializeComplete) global.classPath.asURLs - else new PathResolver(settings).result.asURLs // the compiler's classpath - ) - // NOTE: Exposed to repl package since accessed indirectly from SparkIMain - private[repl] def settings = currentSettings - private def mostRecentLine = prevRequestList match { - case Nil => "" - case req :: _ => req.originalLine - } - // Run the code body with the given boolean settings flipped to true. - private def withoutWarnings[T](body: => T): T = beQuietDuring { - val saved = settings.nowarn.value - if (!saved) - settings.nowarn.value = true - - try body - finally if (!saved) settings.nowarn.value = false - } - - /** construct an interpreter that reports to Console */ - def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) - def this() = this(new Settings()) - - private lazy val repllog: Logger = new Logger { - val out: JPrintWriter = imain.out - val isInfo: Boolean = BooleanProp keyExists "scala.repl.info" - val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug" - val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace" - } - private[repl] lazy val formatting: Formatting = new Formatting { - val prompt = Properties.shellPromptString - } - - // NOTE: Exposed to repl package since used by SparkExprTyper and SparkILoop - private[repl] lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) - - /** - * Determines if errors were reported (typically during compilation). - * - * @note This is not for runtime errors - * - * @return True if had errors, otherwise false - */ - @DeveloperApi - def isReportingErrors = reporter.hasErrors - - import formatting._ - import reporter.{ printMessage, withoutTruncating } - - // This exists mostly because using the reporter too early leads to deadlock. - private def echo(msg: String) { Console println msg } - private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }")) - private def _initialize() = { - try { - // todo. if this crashes, REPL will hang - new _compiler.Run() compileSources _initSources - _initializeComplete = true - true - } - catch AbstractOrMissingHandler() - } - private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" - - // argument is a thunk to execute after init is done - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def initialize(postInitSignal: => Unit) { - synchronized { - if (_isInitialized == null) { - _isInitialized = io.spawn { - try _initialize() - finally postInitSignal - } - } - } - } - - /** - * Initializes the underlying compiler/interpreter in a blocking fashion. - * - * @note Must be executed before using SparkIMain! - */ - @DeveloperApi - def initializeSynchronous(): Unit = { - if (!isInitializeComplete) { - _initialize() - assert(global != null, global) - } - } - private def isInitializeComplete = _initializeComplete - - /** the public, go through the future compiler */ - - /** - * The underlying compiler used to generate ASTs and execute code. 
- */ - @DeveloperApi - lazy val global: Global = { - if (isInitializeComplete) _compiler - else { - // If init hasn't been called yet you're on your own. - if (_isInitialized == null) { - logWarning("Warning: compiler accessed before init set up. Assuming no postInit code.") - initialize(()) - } - // // blocks until it is ; false means catastrophic failure - if (_isInitialized.get()) _compiler - else null - } - } - @deprecated("Use `global` for access to the compiler instance.", "2.9.0") - private lazy val compiler: global.type = global - - import global._ - import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember} - import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass} - - private implicit class ReplTypeOps(tp: Type) { - def orElse(other: => Type): Type = if (tp ne NoType) tp else other - def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) - } - - // TODO: If we try to make naming a lazy val, we run into big time - // scalac unhappiness with what look like cycles. It has not been easy to - // reduce, but name resolution clearly takes different paths. - // NOTE: Exposed to repl package since used by SparkExprTyper - private[repl] object naming extends { - val global: imain.global.type = imain.global - } with Naming { - // make sure we don't overwrite their unwisely named res3 etc. - def freshUserTermName(): TermName = { - val name = newTermName(freshUserVarName()) - if (definedNameMap contains name) freshUserTermName() - else name - } - def isUserTermName(name: Name) = isUserVarName("" + name) - def isInternalTermName(name: Name) = isInternalVarName("" + name) - } - import naming._ - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] object deconstruct extends { - val global: imain.global.type = imain.global - } with StructuredTypeStrings - - // NOTE: Exposed to repl package since used by SparkImports - private[repl] lazy val memberHandlers = new { - val intp: imain.type = imain - } with SparkMemberHandlers - import memberHandlers._ - - /** - * Suppresses overwriting print results during the operation. - * - * @param body The block to execute - * @tparam T The return type of the block - * - * @return The result from executing the block - */ - @DeveloperApi - def beQuietDuring[T](body: => T): T = { - val saved = printResults - printResults = false - try body - finally printResults = saved - } - - /** - * Completely masks all output during the operation (minus JVM standard - * out and error). 
- * - * @param operation The block to execute - * @tparam T The return type of the block - * - * @return The result from executing the block - */ - @DeveloperApi - def beSilentDuring[T](operation: => T): T = { - val saved = totalSilence - totalSilence = true - try operation - finally totalSilence = saved - } - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def quietRun[T](code: String) = beQuietDuring(interpret(code)) - - private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { - case t: ControlThrowable => throw t - case t: Throwable => - logDebug(label + ": " + unwrap(t)) - logDebug(stackTraceString(unwrap(t))) - alt - } - /** takes AnyRef because it may be binding a Throwable or an Exceptional */ - - private def withLastExceptionLock[T](body: => T, alt: => T): T = { - assert(bindExceptions, "withLastExceptionLock called incorrectly.") - bindExceptions = false - - try beQuietDuring(body) - catch logAndDiscard("withLastExceptionLock", alt) - finally bindExceptions = true - } - - /** - * Contains the code (in string form) representing a wrapper around all - * code executed by this instance. - * - * @return The wrapper code as a string - */ - @DeveloperApi - def executionWrapper = _executionWrapper - - /** - * Sets the code to use as a wrapper around all code executed by this - * instance. - * - * @param code The wrapper code as a string - */ - @DeveloperApi - def setExecutionWrapper(code: String) = _executionWrapper = code - - /** - * Clears the code used as a wrapper around all code executed by - * this instance. - */ - @DeveloperApi - def clearExecutionWrapper() = _executionWrapper = "" - - /** interpreter settings */ - private lazy val isettings = new SparkISettings(this) - - /** - * Instantiates a new compiler used by SparkIMain. Overridable to provide - * own instance of a compiler. - * - * @param settings The settings to provide the compiler - * @param reporter The reporter to use for compiler output - * - * @return The compiler as a Global - */ - @DeveloperApi - protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = { - settings.outputDirs setSingleOutput virtualDirectory - settings.exposeEmptyPackage.value = true - new Global(settings, reporter) with ReplGlobal { - override def toString: String = "" - } - } - - /** - * Adds any specified jars to the compile and runtime classpaths. 
- * - * @note Currently only supports jars, not directories - * @param urls The list of items to add to the compile and runtime classpaths - */ - @DeveloperApi - def addUrlsToClassPath(urls: URL*): Unit = { - new Run // Needed to force initialization of "something" to correctly load Scala classes from jars - urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution - updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling - } - - private def updateCompilerClassPath(urls: URL*): Unit = { - require(!global.forMSIL) // Only support JavaPlatform - - val platform = global.platform.asInstanceOf[JavaPlatform] - - val newClassPath = mergeUrlsIntoClassPath(platform, urls: _*) - - // NOTE: Must use reflection until this is exposed/fixed upstream in Scala - val fieldSetter = platform.getClass.getMethods - .find(_.getName.endsWith("currentClassPath_$eq")).get - fieldSetter.invoke(platform, Some(newClassPath)) - - // Reload all jars specified into our compiler - global.invalidateClassPathEntries(urls.map(_.getPath): _*) - } - - private def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { - // Collect our new jars/directories and add them to the existing set of classpaths - val allClassPaths = ( - platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++ - urls.map(url => { - platform.classPath.context.newClassPath( - if (url.getProtocol == "file") { - val f = new File(url.getPath) - if (f.isDirectory) - io.AbstractFile.getDirectory(f) - else - io.AbstractFile.getFile(f) - } else { - io.AbstractFile.getURL(url) - } - ) - }) - ).distinct - - // Combine all of our classpaths (old and new) into one merged classpath - new MergedClassPath(allClassPaths, platform.classPath.context) - } - - /** - * Represents the parent classloader used by this instance. Can be - * overridden to provide alternative classloader. - * - * @return The classloader used as the parent loader of this instance - */ - @DeveloperApi - protected def parentClassLoader: ClassLoader = - SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) - - /* A single class loader is used for all commands interpreted by this Interpreter. - It would also be possible to create a new class loader for each command - to interpret. The advantages of the current approach are: - - - Expressions are only evaluated one time. This is especially - significant for I/O, e.g. "val x = Console.readLine" - - The main disadvantage is: - - - Objects, classes, and methods cannot be rebound. Instead, definitions - shadow the old ones, and old code objects refer to the old - definitions. - */ - private def resetClassLoader() = { - logDebug("Setting new classloader: was " + _classLoader) - _classLoader = null - ensureClassLoader() - } - private final def ensureClassLoader() { - if (_classLoader == null) - _classLoader = makeClassLoader() - } - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def classLoader: AbstractFileClassLoader = { - ensureClassLoader() - _classLoader - } - private class TranslatingClassLoader(parent: ClassLoader) extends AbstractFileClassLoader(virtualDirectory, parent) { - /** Overridden here to try translating a simple name to the generated - * class name if the original attempt fails. This method is used by - * getResourceAsStream as well as findClass. 
- */ - override protected def findAbstractFile(name: String): AbstractFile = { - super.findAbstractFile(name) match { - // deadlocks on startup if we try to translate names too early - case null if isInitializeComplete => - generatedName(name) map (x => super.findAbstractFile(x)) orNull - case file => - file - } - } - } - private def makeClassLoader(): AbstractFileClassLoader = - new TranslatingClassLoader(parentClassLoader match { - case null => ScalaClassLoader fromURLs compilerClasspath - case p => - _runtimeClassLoader = new URLClassLoader(compilerClasspath, p) with ExposeAddUrl - _runtimeClassLoader - }) - - private def getInterpreterClassLoader() = classLoader - - // Set the current Java "context" class loader to this interpreter's class loader - // NOTE: Exposed to repl package since used by SparkILoopInit - private[repl] def setContextClassLoader() = classLoader.setAsContext() - - /** - * Returns the real name of a class based on its repl-defined name. - * - * ==Example== - * Given a simple repl-defined name, returns the real name of - * the class representing it, e.g. for "Bippy" it may return - * {{{ - * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy - * }}} - * - * @param simpleName The repl-defined name whose real name to retrieve - * - * @return Some real name if the simple name exists, else None - */ - @DeveloperApi - def generatedName(simpleName: String): Option[String] = { - if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING) - else optFlatName(simpleName) - } - - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def flatName(id: String) = optFlatName(id) getOrElse id - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) - - /** - * Retrieves all simple names contained in the current instance. - * - * @return A list of sorted names - */ - @DeveloperApi - def allDefinedNames = definedNameMap.keys.toList.sorted - - private def pathToType(id: String): String = pathToName(newTypeName(id)) - // NOTE: Exposed to repl package since used by SparkILoop - private[repl] def pathToTerm(id: String): String = pathToName(newTermName(id)) - - /** - * Retrieves the full code path to access the specified simple name - * content. - * - * @param name The simple name of the target whose path to determine - * - * @return The full path used to access the specified target (name) - */ - @DeveloperApi - def pathToName(name: Name): String = { - if (definedNameMap contains name) - definedNameMap(name) fullPath name - else name.toString - } - - /** Most recent tree handled which wasn't wholly synthetic. */ - private def mostRecentlyHandledTree: Option[Tree] = { - prevRequests.reverse foreach { req => - req.handlers.reverse foreach { - case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) - case _ => () - } - } - None - } - - /** Stubs for work in progress. 
*/ - private def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { - for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { - logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) - } - } - - private def handleTermRedefinition(name: TermName, old: Request, req: Request) = { - for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { - // Printing the types here has a tendency to cause assertion errors, like - // assertion failed: fatal: has owner value x, but a class owner is required - // so DBG is by-name now to keep it in the family. (It also traps the assertion error, - // but we don't want to unnecessarily risk hosing the compiler's internal state.) - logDebug("Redefining term '%s'\n %s -> %s".format(name, t1, t2)) - } - } - - private def recordRequest(req: Request) { - if (req == null || referencedNameMap == null) - return - - prevRequests += req - req.referencedNames foreach (x => referencedNameMap(x) = req) - - // warning about serially defining companions. It'd be easy - // enough to just redefine them together but that may not always - // be what people want so I'm waiting until I can do it better. - for { - name <- req.definedNames filterNot (x => req.definedNames contains x.companionName) - oldReq <- definedNameMap get name.companionName - newSym <- req.definedSymbols get name - oldSym <- oldReq.definedSymbols get name.companionName - if Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule } - } { - afterTyper(replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")) - replwarn("Companions must be defined together; you may wish to use :paste mode for this.") - } - - // Updating the defined name map - req.definedNames foreach { name => - if (definedNameMap contains name) { - if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req) - else handleTermRedefinition(name.toTermName, definedNameMap(name), req) - } - definedNameMap(name) = req - } - } - - private def replwarn(msg: => String) { - if (!settings.nowarnings.value) - printMessage(msg) - } - - private def isParseable(line: String): Boolean = { - beSilentDuring { - try parse(line) match { - case Some(xs) => xs.nonEmpty // parses as-is - case None => true // incomplete - } - catch { case x: Exception => // crashed the compiler - replwarn("Exception in isParseable(\"" + line + "\"): " + x) - false - } - } - } - - private def compileSourcesKeepingRun(sources: SourceFile*) = { - val run = new Run() - reporter.reset() - run compileSources sources.toList - (!reporter.hasErrors, run) - } - - /** - * Compiles specified source files. - * - * @param sources The sequence of source files to compile - * - * @return True if successful, otherwise false - */ - @DeveloperApi - def compileSources(sources: SourceFile*): Boolean = - compileSourcesKeepingRun(sources: _*)._1 - - /** - * Compiles a string of code. - * - * @param code The string of code to compile - * - * @return True if successful, otherwise false - */ - @DeveloperApi - def compileString(code: String): Boolean = - compileSources(new BatchSourceFile("