Commit 29e1ee2

merge with mater, and add a testcase
2 parents: b463ac7 + de62ddf

120 files changed: +2432 -871 lines


R/pkg/R/DataFrame.R

Lines changed: 17 additions & 5 deletions
@@ -1727,14 +1727,21 @@ setMethod("$", signature(x = "SparkDataFrame"),
             getColumn(x, name)
           })

-#' @param value a Column or \code{NULL}. If \code{NULL}, the specified Column is dropped.
+#' @param value a Column or an atomic vector in the length of 1 as literal value, or \code{NULL}.
+#' If \code{NULL}, the specified Column is dropped.
 #' @rdname select
 #' @name $<-
 #' @aliases $<-,SparkDataFrame-method
 #' @note $<- since 1.4.0
 setMethod("$<-", signature(x = "SparkDataFrame"),
           function(x, name, value) {
-            stopifnot(class(value) == "Column" || is.null(value))
+            if (class(value) != "Column" && !is.null(value)) {
+              if (isAtomicLengthOne(value)) {
+                value <- lit(value)
+              } else {
+                stop("value must be a Column, literal value as atomic in length of 1, or NULL")
+              }
+            }

             if (is.null(value)) {
               nx <- drop(x, name)
@@ -1947,10 +1954,10 @@ setMethod("selectExpr",
 #'
 #' @param x a SparkDataFrame.
 #' @param colName a column name.
-#' @param col a Column expression.
+#' @param col a Column expression, or an atomic vector in the length of 1 as literal value.
 #' @return A SparkDataFrame with the new column added or the existing column replaced.
 #' @family SparkDataFrame functions
-#' @aliases withColumn,SparkDataFrame,character,Column-method
+#' @aliases withColumn,SparkDataFrame,character-method
 #' @rdname withColumn
 #' @name withColumn
 #' @seealso \link{rename} \link{mutate}
@@ -1963,11 +1970,16 @@ setMethod("selectExpr",
 #' newDF <- withColumn(df, "newCol", df$col1 * 5)
 #' # Replace an existing column
 #' newDF2 <- withColumn(newDF, "newCol", newDF$col1)
+#' newDF3 <- withColumn(newDF, "newCol", 42)
 #' }
 #' @note withColumn since 1.4.0
 setMethod("withColumn",
-          signature(x = "SparkDataFrame", colName = "character", col = "Column"),
+          signature(x = "SparkDataFrame", colName = "character"),
           function(x, colName, col) {
+            if (class(col) != "Column") {
+              if (!isAtomicLengthOne(col)) stop("Literal value must be atomic in length of 1")
+              col <- lit(col)
+            }
             sdf <- callJMethod(x@sdf, "withColumn", colName, col@jc)
             dataFrame(sdf)
           })
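
A minimal usage sketch of the new literal-value support (assuming an active SparkR session and the isAtomicLengthOne helper added in R/pkg/R/utils.R below; the data set and column names here are illustrative only):

  df <- createDataFrame(faithful)
  df$label <- 1                        # length-one atomic value, wrapped in lit() by $<-
  df2 <- withColumn(df, "answer", 42)  # same for withColumn
  df$label <- NULL                     # NULL still drops the column
  # Anything that is not a Column, NULL, or a length-one atomic vector is rejected, e.g.
  # withColumn(df, "bad", c(1, 2)) stops with "Literal value must be atomic in length of 1".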

R/pkg/R/SQLContext.R

Lines changed: 14 additions & 6 deletions
@@ -184,8 +184,11 @@ getDefaultSqlSource <- function() {
 #'
 #' Converts R data.frame or list into SparkDataFrame.
 #'
-#' @param data an RDD or list or data.frame.
+#' @param data a list or data.frame.
 #' @param schema a list of column names or named list (StructType), optional.
+#' @param samplingRatio Currently not used.
+#' @param numPartitions the number of partitions of the SparkDataFrame. Defaults to 1, this is
+#' limited by length of the list or number of rows of the data.frame
 #' @return A SparkDataFrame.
 #' @rdname createDataFrame
 #' @export
@@ -195,12 +198,14 @@ getDefaultSqlSource <- function() {
 #' df1 <- as.DataFrame(iris)
 #' df2 <- as.DataFrame(list(3,4,5,6))
 #' df3 <- createDataFrame(iris)
+#' df4 <- createDataFrame(cars, numPartitions = 2)
 #' }
 #' @name createDataFrame
 #' @method createDataFrame default
 #' @note createDataFrame since 1.4.0
 # TODO(davies): support sampling and infer type from NA
-createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
+createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0,
+                                    numPartitions = NULL) {
   sparkSession <- getSparkSession()

   if (is.data.frame(data)) {
@@ -233,7 +238,11 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {

   if (is.list(data)) {
     sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
-    rdd <- parallelize(sc, data)
+    if (!is.null(numPartitions)) {
+      rdd <- parallelize(sc, data, numSlices = numToInt(numPartitions))
+    } else {
+      rdd <- parallelize(sc, data, numSlices = 1)
+    }
   } else if (inherits(data, "RDD")) {
     rdd <- data
   } else {
@@ -283,14 +292,13 @@ createDataFrame <- function(x, ...) {
   dispatchFunc("createDataFrame(data, schema = NULL)", x, ...)
 }

-#' @param samplingRatio Currently not used.
 #' @rdname createDataFrame
 #' @aliases createDataFrame
 #' @export
 #' @method as.DataFrame default
 #' @note as.DataFrame since 1.6.0
-as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
-  createDataFrame(data, schema)
+as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) {
+  createDataFrame(data, schema, samplingRatio, numPartitions)
 }

 #' @param ... additional argument(s).
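
A short sketch of the new numPartitions argument (assuming an active SparkR session; the expected partition counts follow the test cases added in test_sparkSQL.R below, which use toRDD and getNumPartitions to inspect the partitioning):

  df <- createDataFrame(cars, numPartitions = 2)
  getNumPartitions(toRDD(df))   # 2
  # Requests larger than the number of rows are capped (cars has 50 rows):
  df <- createDataFrame(cars, numPartitions = 60)
  getNumPartitions(toRDD(df))   # 50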

R/pkg/R/context.R

Lines changed: 34 additions & 5 deletions
@@ -91,6 +91,16 @@ objectFile <- function(sc, path, minPartitions = NULL) {
 #' will write it to disk and send the file name to JVM. Also to make sure each slice is not
 #' larger than that limit, number of slices may be increased.
 #'
+#' In 2.2.0 we are changing how the numSlices are used/computed to handle
+#' 1 < (length(coll) / numSlices) << length(coll) better, and to get the exact number of slices.
+#' This change affects both createDataFrame and spark.lapply.
+#' In the specific one case that it is used to convert R native object into SparkDataFrame, it has
+#' always been kept at the default of 1. In the case the object is large, we are explicitly setting
+#' the parallism to numSlices (which is still 1).
+#'
+#' Specifically, we are changing to split positions to match the calculation in positions() of
+#' ParallelCollectionRDD in Spark.
+#'
 #' @param sc SparkContext to use
 #' @param coll collection to parallelize
 #' @param numSlices number of partitions to create in the RDD
@@ -107,6 +117,8 @@ parallelize <- function(sc, coll, numSlices = 1) {
   # TODO: bound/safeguard numSlices
   # TODO: unit tests for if the split works for all primitives
   # TODO: support matrix, data frame, etc
+
+  # Note, for data.frame, createDataFrame turns it into a list before it calls here.
   # nolint start
   # suppress lintr warning: Place a space before left parenthesis, except in a function call.
   if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) {
@@ -128,12 +140,29 @@ parallelize <- function(sc, coll, numSlices = 1) {
   objectSize <- object.size(coll)

   # For large objects we make sure the size of each slice is also smaller than sizeLimit
-  numSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
-  if (numSlices > length(coll))
-    numSlices <- length(coll)
+  numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
+  if (numSerializedSlices > length(coll))
+    numSerializedSlices <- length(coll)
+
+  # Generate the slice ids to put each row
+  # For instance, for numSerializedSlices of 22, length of 50
+  #  [1] 0 0 2 2 4 4 6 6 6 9 9 11 11 13 13 15 15 15 18 18 20 20 22 22 22
+  # [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47
+  # Notice the slice group with 3 slices (ie. 6, 15, 22) are roughly evenly spaced.
+  # We are trying to reimplement the calculation in the positions method in ParallelCollectionRDD
+  splits <- if (numSerializedSlices > 0) {
+    unlist(lapply(0: (numSerializedSlices - 1), function(x) {
+      # nolint start
+      start <- trunc((x * length(coll)) / numSerializedSlices)
+      end <- trunc(((x + 1) * length(coll)) / numSerializedSlices)
+      # nolint end
+      rep(start, end - start)
+    }))
+  } else {
+    1
+  }

-  sliceLen <- ceiling(length(coll) / numSlices)
-  slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)])
+  slices <- split(coll, splits)

   # Serialize each slice: obtain a list of raws, or a list of lists (slices) of
   # 2-tuples of raws
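
The new code tags each element with the start offset of the slice it falls into, mirroring positions() in Spark's ParallelCollectionRDD. A standalone sketch of just that calculation (plain R, not part of the SparkR API; the function name sliceIds is made up for illustration), reproducing the example in the comment above:

  sliceIds <- function(n, numSlices) {
    unlist(lapply(0:(numSlices - 1), function(x) {
      start <- trunc((x * n) / numSlices)
      end <- trunc(((x + 1) * n) / numSlices)
      rep(start, end - start)   # every element of this slice is tagged with its start offset
    }))
  }
  sliceIds(50, 22)
  #  [1] 0 0 2 2 4 4 6 6 6 9 9 11 11 13 13 15 15 15 18 18 20 20 22 22 22
  # [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47
  length(unique(sliceIds(50, 22)))   # exactly 22 slices, as requested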

R/pkg/R/mllib_clustering.R

Lines changed: 11 additions & 2 deletions
@@ -175,6 +175,10 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact
 #' @param k number of centers.
 #' @param maxIter maximum iteration number.
 #' @param initMode the initialization algorithm choosen to fit the model.
+#' @param seed the random seed for cluster initialization.
+#' @param initSteps the number of steps for the k-means|| initialization mode.
+#' This is an advanced setting, the default of 2 is almost always enough. Must be > 0.
+#' @param tol convergence tolerance of iterations.
 #' @param ... additional argument(s) passed to the method.
 #' @return \code{spark.kmeans} returns a fitted k-means model.
 #' @rdname spark.kmeans
@@ -204,11 +208,16 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact
 #' @note spark.kmeans since 2.0.0
 #' @seealso \link{predict}, \link{read.ml}, \link{write.ml}
 setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"),
-          function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) {
+          function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random"),
+                   seed = NULL, initSteps = 2, tol = 1E-4) {
            formula <- paste(deparse(formula), collapse = "")
            initMode <- match.arg(initMode)
+           if (!is.null(seed)) {
+             seed <- as.character(as.integer(seed))
+           }
            jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula,
-                               as.integer(k), as.integer(maxIter), initMode)
+                               as.integer(k), as.integer(maxIter), initMode, seed,
+                               as.integer(initSteps), as.numeric(tol))
            new("KMeansModel", jobj = jobj)
          })
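
A hedged usage sketch of the new arguments, mirroring the test case added in test_mllib_clustering.R below (the toy data here is a two-column variant of the data used in that test; fixing the seed makes the "random" initialization reproducible):

  cols <- as.data.frame(cbind(col1 = c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0),
                              col2 = c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)))
  df <- createDataFrame(cols)
  model <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
                        initMode = "random", seed = 1, tol = 1E-5)
  head(fitted(model))   # cluster assignments; rerunning with the same seed gives the same result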

R/pkg/R/utils.R

Lines changed: 4 additions & 0 deletions
@@ -863,3 +863,7 @@ basenameSansExtFromUrl <- function(url) {
   # then, strip extension by the last '.'
   sub("([^.]+)\\.[[:alnum:]]+$", "\\1", filename)
 }
+
+isAtomicLengthOne <- function(x) {
+  is.atomic(x) && length(x) == 1
+}

R/pkg/inst/tests/testthat/test_mllib_clustering.R

Lines changed: 20 additions & 0 deletions
@@ -132,6 +132,26 @@ test_that("spark.kmeans", {
   expect_true(summary2$is.loaded)

   unlink(modelPath)
+
+  # Test Kmeans on dataset that is sensitive to seed value
+  col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
+  col2 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
+  col3 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
+  cols <- as.data.frame(cbind(col1, col2, col3))
+  df <- createDataFrame(cols)
+
+  model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
+                         initMode = "random", seed = 1, tol = 1E-5)
+  model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
+                         initMode = "random", seed = 22222, tol = 1E-5)
+
+  fitted.model1 <- fitted(model1)
+  fitted.model2 <- fitted(model2)
+  # The predicted clusters are different
+  expect_equal(sort(collect(distinct(select(fitted.model1, "prediction")))$prediction),
+               c(0, 1, 2, 3))
+  expect_equal(sort(collect(distinct(select(fitted.model2, "prediction")))$prediction),
+               c(0, 1, 2))
 })

 test_that("spark.lda with libsvm", {

R/pkg/inst/tests/testthat/test_rdd.R

Lines changed: 2 additions & 2 deletions
@@ -381,8 +381,8 @@ test_that("aggregateRDD() on RDDs", {
 test_that("zipWithUniqueId() on RDDs", {
   rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
   actual <- collectRDD(zipWithUniqueId(rdd))
-  expected <- list(list("a", 0), list("b", 3), list("c", 1),
-                   list("d", 4), list("e", 2))
+  expected <- list(list("a", 0), list("b", 1), list("c", 4),
+                   list("d", 2), list("e", 5))
   expect_equal(actual, expected)

   rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
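
The expected ids change because the new split positions in parallelize (see the context.R diff above) partition the five elements as ("a"), ("b", "c"), ("d", "e") instead of ("a", "b"), ("c", "d"), ("e"); zipWithUniqueId assigns the k-th item of partition i the id i + k * numPartitions, per its documentation. A quick check using the hypothetical sliceIds helper sketched after the context.R diff:

  sliceIds(5, 3)
  # 0 1 1 3 3  -> groups of sizes 1, 2, 2
  # partition 0: "a" -> 0
  # partition 1: "b" -> 1, "c" -> 1 + 3 = 4
  # partition 2: "d" -> 2, "e" -> 2 + 3 = 5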

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 40 additions & 1 deletion
@@ -196,6 +196,26 @@ test_that("create DataFrame from RDD", {
   expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float")))
   expect_equal(as.list(collect(where(df, df$name == "John"))),
                list(name = "John", age = 19L, height = 176.5))
+  expect_equal(getNumPartitions(toRDD(df)), 1)
+
+  df <- as.DataFrame(cars, numPartitions = 2)
+  expect_equal(getNumPartitions(toRDD(df)), 2)
+  df <- createDataFrame(cars, numPartitions = 3)
+  expect_equal(getNumPartitions(toRDD(df)), 3)
+  # validate limit by num of rows
+  df <- createDataFrame(cars, numPartitions = 60)
+  expect_equal(getNumPartitions(toRDD(df)), 50)
+  # validate when 1 < (length(coll) / numSlices) << length(coll)
+  df <- createDataFrame(cars, numPartitions = 20)
+  expect_equal(getNumPartitions(toRDD(df)), 20)
+
+  df <- as.DataFrame(data.frame(0))
+  expect_is(df, "SparkDataFrame")
+  df <- createDataFrame(list(list(1)))
+  expect_is(df, "SparkDataFrame")
+  df <- as.DataFrame(data.frame(0), numPartitions = 2)
+  # no data to partition, goes to 1
+  expect_equal(getNumPartitions(toRDD(df)), 1)

   setHiveContext(sc)
   sql("CREATE TABLE people (name string, age double, height float)")
@@ -213,7 +233,8 @@ test_that("createDataFrame uses files for large objects", {
   # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value
   conf <- callJMethod(sparkSession, "conf")
   callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100")
-  df <- suppressWarnings(createDataFrame(iris))
+  df <- suppressWarnings(createDataFrame(iris, numPartitions = 3))
+  expect_equal(getNumPartitions(toRDD(df)), 3)

   # Resetting the conf back to default value
   callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10))
@@ -1001,6 +1022,17 @@ test_that("select operators", {
   expect_equal(columns(df), c("name", "age", "age2"))
   expect_equal(count(where(df, df$age2 == df$age * 2)), 2)

+  df$age2 <- 21
+  expect_equal(columns(df), c("name", "age", "age2"))
+  expect_equal(count(where(df, df$age2 == 21)), 3)
+
+  df$age2 <- c(22)
+  expect_equal(columns(df), c("name", "age", "age2"))
+  expect_equal(count(where(df, df$age2 == 22)), 3)
+
+  expect_error(df$age3 <- c(22, NA),
+               "value must be a Column, literal value as atomic in length of 1, or NULL")
+
   # Test parameter drop
   expect_equal(class(df[, 1]) == "SparkDataFrame", T)
   expect_equal(class(df[, 1, drop = T]) == "Column", T)
@@ -1778,6 +1810,13 @@ test_that("withColumn() and withColumnRenamed()", {
   expect_equal(length(columns(newDF)), 2)
   expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32)

+  newDF <- withColumn(df, "age", 18)
+  expect_equal(length(columns(newDF)), 2)
+  expect_equal(first(newDF)$age, 18)
+
+  expect_error(withColumn(df, "age", list("a")),
+               "Literal value must be atomic in length of 1")
+
   newDF2 <- withColumnRenamed(df, "age", "newerAge")
   expect_equal(length(columns(newDF2)), 2)
   expect_equal(columns(newDF2)[1], "newerAge")
