Skip to content

Commit 9714fa5

Browse files
committed
[SPARK-25234][SPARKR] avoid integer overflow in parallelize
## What changes were proposed in this pull request? `parallelize` uses integer multiplication to determine the split indices. It might cause integer overflow. ## How was this patch tested? unit test Closes #22225 from mengxr/SPARK-25234. Authored-by: Xiangrui Meng <[email protected]> Signed-off-by: Xiangrui Meng <[email protected]>
1 parent f8346d2 commit 9714fa5

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

R/pkg/R/context.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,10 @@ parallelize <- function(sc, coll, numSlices = 1) {
138138

139139
sizeLimit <- getMaxAllocationLimit(sc)
140140
objectSize <- object.size(coll)
141+
len <- length(coll)
141142

142143
# For large objects we make sure the size of each slice is also smaller than sizeLimit
143-
numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
144-
if (numSerializedSlices > length(coll))
145-
numSerializedSlices <- length(coll)
144+
numSerializedSlices <- min(len, max(numSlices, ceiling(objectSize / sizeLimit)))
146145

147146
# Generate the slice ids to put each row
148147
# For instance, for numSerializedSlices of 22, length of 50
@@ -153,8 +152,8 @@ parallelize <- function(sc, coll, numSlices = 1) {
153152
splits <- if (numSerializedSlices > 0) {
154153
unlist(lapply(0: (numSerializedSlices - 1), function(x) {
155154
# nolint start
156-
start <- trunc((x * length(coll)) / numSerializedSlices)
157-
end <- trunc(((x + 1) * length(coll)) / numSerializedSlices)
155+
start <- trunc((as.numeric(x) * len) / numSerializedSlices)
156+
end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices)
158157
# nolint end
159158
rep(start, end - start)
160159
}))

R/pkg/tests/fulltests/test_context.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,10 @@ test_that("add and get file to be downloaded with Spark job on every node", {
240240
unlink(path, recursive = TRUE)
241241
sparkR.session.stop()
242242
})
243+
244+
test_that("SPARK-25234: parallelize should not have integer overflow", {
245+
sc <- sparkR.sparkContext(master = sparkRTestMaster)
246+
# 47000 * 47000 exceeds integer range
247+
parallelize(sc, 1:47000, 47000)
248+
sparkR.session.stop()
249+
})

0 commit comments

Comments
 (0)