Skip to content

Commit 0d07770

Browse files
author
cafreeman
committed
Added limit and updated take
The `limit` function now calls the native DataFrame `limit` and returns a new DataFrame. `take` now collects the results of `limit` and returns a `data.frame` instead of converting to an RRDD and using the RDD `take` method.
1 parent 301d8e5 commit 0d07770

File tree

3 files changed

+50
-11
lines changed

3 files changed

+50
-11
lines changed

pkg/NAMESPACE

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ exportClasses("DataFrame")
8383

8484
exportMethods("printSchema",
8585
"registerTempTable",
86-
"toRDD")
86+
"toRDD",
87+
"limit")
8788

8889
export("jsonFile",
8990
"parquetFile",

pkg/R/DataFrame.R

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,10 @@ setMethod("registerTempTable",
110110
setMethod("count",
111111
signature(x = "DataFrame"),
112112
function(x) {
113-
sdf <- x@sdf
114-
callJMethod(sdf, "count")
113+
callJMethod(x@sdf, "count")
115114
})
116115

117-
# Collects all the elements of a Spark DataFrame and coerces them into an R data.frame.
116+
#' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame.
118117

119118
#' @rdname collect-methods
120119
#' @export
@@ -146,7 +145,35 @@ setMethod("collect",
146145
dfOut
147146
})
148147

149-
# Take the first NUM elements in a DataFrame and return a named list for each row.
148+
#' Limit
149+
#'
150+
#' Limit the resulting DataFrame to the number of rows specified.
151+
#'
152+
#' @param df A SparkSQL DataFrame
153+
#' @param num The number of rows to return
154+
#' @return A new DataFrame containing the number of rows specified.
155+
#'
156+
#' @rdname limit
157+
#' @export
158+
#' @examples
159+
#' \dontrun{
160+
#' sc <- sparkR.init()
161+
#' sqlCtx <- sparkRSQL.init(sc)
162+
#' path <- "path/to/file.json"
163+
#' df <- jsonFile(sqlCtx, path)
164+
#' limitedDF <- limit(df, 10)
165+
#' }
166+
167+
setGeneric("limit", function(df, num) {standardGeneric("limit") })
168+
169+
setMethod("limit",
170+
signature(df = "DataFrame", num = "numeric"),
171+
function(df, num) {
172+
res <- callJMethod(df@sdf, "limit", as.integer(num))
173+
dataFrame(res)
174+
})
175+
176+
# Take the first NUM elements in a DataFrame and return a the results as a data.frame
150177

151178
#' @rdname take
152179
#' @export
@@ -162,8 +189,8 @@ setMethod("collect",
162189
setMethod("take",
163190
signature(rdd = "DataFrame", num = "numeric"),
164191
function(rdd, num) {
165-
rddIn <- toRDD(rdd)
166-
take(rddIn, num)
192+
limited <- limit(rdd, num)
193+
collect(limited)
167194
})
168195

169196
#' toRDD()
@@ -189,8 +216,12 @@ setMethod("toRDD",
189216
signature(df = "DataFrame"),
190217
function(df) {
191218
jrdd <- callJStatic("edu.berkeley.cs.amplab.sparkr.SQLUtils", "dfToRowRDD", df@sdf)
192-
names <- callJMethod(df@sdf, "columns")
193-
RDD(jrdd, serializedMode = "row", colNames = names)
219+
colNames <- callJMethod(df@sdf, "columns")
220+
rdd <- RDD(jrdd, serializedMode = "row")
221+
lapply(rdd, function(row) {
222+
names(row) <- colNames
223+
row
224+
})
194225
})
195226

196227
############################## RDD Map Functions ##################################

pkg/inst/tests/test_sparkSQL.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,17 @@ test_that("collect() returns a data.frame", {
113113
expect_true(ncol(rdf) == 2)
114114
})
115115

116+
test_that("limit() returns DataFrame with the correct number of rows", {
117+
df <- jsonFile(sqlCtx, jsonPath)
118+
dfLimited <- limit(df, 2)
119+
expect_true(inherits(dfLimited, "DataFrame"))
120+
expect_true(count(dfLimited) == 2)
121+
})
122+
116123
test_that("collect() and take() on a DataFrame return the same number of rows and columns", {
117124
df <- jsonFile(sqlCtx, jsonPath)
118-
expect_true(nrow(collect(df)) == length(take(df, 10)))
119-
expect_true(ncol(collect(df)) == length(take(df, 10)[[1]]))
125+
expect_true(nrow(collect(df)) == nrow(take(df, 10)))
126+
expect_true(ncol(collect(df)) == ncol(take(df, 10)))
120127
})
121128

122129
test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", {

0 commit comments

Comments
 (0)