-
Notifications
You must be signed in to change notification settings - Fork 2
/
mx.io.bucket.iter.R
106 lines (99 loc) · 6.9 KB
/
mx.io.bucket.iter.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
BucketIter <- setRefClass("BucketIter", fields = c("buckets", "bucket.names", "batch.size",
"data.mask.element", "shuffle", "bucket.plan", "bucketID", "epoch", "batch", "batch.per.bucket",
"last.batch.pad", "batch.per.epoch", "seed"),
methods = list(
initialize = function(buckets,
batch.size, data.mask.element = 0, shuffle = FALSE, seed = 123) {
.self$buckets <- buckets
.self$bucket.names <- names(.self$buckets)
.self$batch.size <- batch.size
.self$data.mask.element <- data.mask.element
.self$epoch <- 0
.self$batch <- 0
.self$shuffle <- shuffle
.self$batch.per.bucket <- 0
.self$batch.per.epoch <- 0
.self$bucket.plan <- NULL
.self$bucketID <- NULL
.self$seed <- seed
.self
}, reset = function() {
buckets_nb <- length(bucket.names)
buckets_id <- seq_len(buckets_nb)
buckets.size <- sapply(.self$buckets, function(x) {
tail(dim(x$data), 1)
})
.self$batch.per.bucket <- ceiling(buckets.size/.self$batch.size)
.self$last.batch.pad <- .self$batch.size - buckets.size %% .self$batch.size
.self$last.batch.pad[.self$last.batch.pad == .self$batch.size] <- 0
.self$batch.per.epoch <- sum(.self$batch.per.bucket)
# Number of batches per epoch given the batch.size
.self$batch.per.epoch <- sum(.self$batch.per.bucket)
.self$epoch <- .self$epoch + 1
.self$batch <- 0
if (.self$shuffle) {
set.seed(.self$seed)
bucket_plan_names <- sample(rep(names(.self$batch.per.bucket), times = .self$batch.per.bucket))
.self$bucket.plan <- ave(bucket_plan_names == bucket_plan_names, bucket_plan_names,
FUN = cumsum)
names(.self$bucket.plan) <- bucket_plan_names
# Return first BucketID at reset for initialization of the model
.self$bucketID <- .self$bucket.plan[1]
.self$buckets <- lapply(.self$buckets, function(x) {
shuffle_id <- sample(tail(dim(x$data), 1))
if (length(dim(x$label)) == 0) {
list(data = x$data[, shuffle_id], label = x$label[shuffle_id])
} else {
list(data = x$data[, shuffle_id], label = x$label[, shuffle_id])
}
})
} else {
bucket_plan_names <- rep(names(.self$batch.per.bucket), times = .self$batch.per.bucket)
.self$bucket.plan <- ave(bucket_plan_names == bucket_plan_names, bucket_plan_names,
FUN = cumsum)
names(.self$bucket.plan) <- bucket_plan_names
}
}, iter.next = function() {
.self$batch <- .self$batch + 1
.self$bucketID <- .self$bucket.plan[batch]
return(.self$batch <= .self$batch.per.epoch)
}, value = function() {
# bucketID is a named integer: the integer indicates the batch id for the given
# bucket (used to fetch appropriate samples within the bucket) the name is a
# character containing the sequence length of the bucket (used to unroll the rnn
# to appropriate sequence length)
idx <- (.self$bucketID - 1) * (.self$batch.size) + seq_len(batch.size)
# Reuse first idx for padding
if (bucketID == .self$batch.per.bucket[names(.self$bucketID)] & !.self$last.batch.pad[names(.self$bucketID)] == 0) {
idx <- c(idx[seq_len(.self$batch.size - .self$last.batch.pad[names(.self$bucketID)])], seq_len(.self$last.batch.pad[names(.self$bucketID)]))
}
data <- .self$buckets[[names(.self$bucketID)]]$data[, idx, drop = F]
seq.mask <- as.integer(names(bucketID)) - apply(data==.self$data.mask.element, 2, sum)
if (length(dim(.self$buckets[[names(.self$bucketID)]]$label)) == 0) {
label <- .self$buckets[[names(.self$bucketID)]]$label[idx]
} else {
label <- .self$buckets[[names(.self$bucketID)]]$label[, idx, drop = F]
}
return(list(data = mx.nd.array(data), seq.mask = mx.nd.array(seq.mask),
label = mx.nd.array(label)))
}, num.pad = function() {
if (bucketID == .self$batch.per.bucket[names(.self$bucketID)] & !.self$last.batch.pad[names(.self$bucketID)] == 0){
return(.self$last.batch.pad[names(.self$bucketID)])
} else return(0)
}, finalize = function() {
}))
#
#' Create Bucket Iter
#'
#' @param buckets The data array.
#' @param batch.size The batch size used to pack the array.
#' @param data.mask.element The element to mask
#' @param shuffle Whether shuffle the data
#' @param seed The random seed
#'
#' @export
mx.io.bucket.iter <- function(buckets, batch.size, data.mask.element = 0, shuffle = FALSE,
seed = 123) {
return(BucketIter$new(buckets = buckets, batch.size = batch.size, data.mask.element = data.mask.element,
shuffle = shuffle, seed = seed))
}