Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
modify tests and docs, have segfault
Browse files Browse the repository at this point in the history
  • Loading branch information
hetong007 committed Oct 11, 2015
1 parent e7efd31 commit a777402
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 9 deletions.
10 changes: 10 additions & 0 deletions R-package/R/context.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ init.context.default <- function() {
#'
#' Set/Get default context for array creation.
#'
#' @rdname mx.ctx
#'
#' @param new, optional takes \code{mx.cpu()} or \code{mx.gpu(id)}, new default ctx.
#'
#' @export
mx.ctx.default <- function(new = NULL) {
if (!is.null(new)) {
Expand All @@ -17,6 +20,13 @@ mx.ctx.default <- function(new = NULL) {
return (mx.ctx.internal.default.value)
}

# TODO need examples

#' @rdname mx.ctx
#'
#' @return Logical indicator
#'
#' @export
is.mx.context <- function(x) {
class(x) == "MXContext"
}
17 changes: 11 additions & 6 deletions R-package/R/initializer.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# TODO(tong, KK) better R alternatives?
mx.util.str.endswith <- function(name, suffix) {
slen <- nchar(suffix)
nlen <- nchar(name)
if (slen > nlen) return (FALSE)
nsuf <- substr(name, nlen - slen + 1, nlen)
return (nsuf == suffix)
# slen <- nchar(suffix)
# nlen <- nchar(name)
# if (slen > nlen) return (FALSE)
# nsuf <- substr(name, nlen - slen + 1, nlen)
# return (nsuf == suffix)
ptrn = paste0(suffix, "\\b")
return(grepl(ptrn, name))
}

#' Internal default value initialization scheme.
Expand All @@ -20,7 +21,9 @@ mx.init.internal.default <- function(name, shape, ctx) {
}

#' Create a initializer that initialize the weight with uniform [-scale, scale]
#'
#' @param scale The scale of uniform distribution
#'
#' @export
mx.init.uniform <- function(scale) {
function(name, shape, ctx) {
Expand All @@ -32,7 +35,9 @@ mx.init.uniform <- function(scale) {
}

#' Create a initializer that initialize the weight with normal(0, sd)
#'
#' @param scale The scale of uniform distribution
#'
#' @export
mx.init.normal <- function(sd) {
function(name, shape, ctx) {
Expand Down
38 changes: 37 additions & 1 deletion R-package/R/ndarray.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
#'
#' @param filename the filename (including the path)
#'
#' @examples
#' mat = mx.nd.array(1:3)
#' mx.nd.save(mat, 'temp.mat')
#' mat2 = mx.nd.load('temp.mat')
#' as.array(mat)
#' as.array(mat2)
#'
#' @export
mx.nd.load <- function(filename) {
filename <- path.expand(filename)
Expand All @@ -13,6 +20,13 @@ mx.nd.load <- function(filename) {
#' @param ndarray the \code{mx.nd.array} object
#' @param filename the filename (including the path)
#'
#' @examples
#' mat = mx.nd.array(1:3)
#' mx.nd.save(mat, 'temp.mat')
#' mat2 = mx.nd.load('temp.mat')
#' as.array(mat)
#' as.array(mat2)
#'
#' @export
mx.nd.save <- function(ndarray, filename) {
filename <- path.expand(filename)
Expand All @@ -29,6 +43,14 @@ mx.nd.internal.empty <- function(shape, ctx=NULL) {
#' @param shape the dimension of the \code{mx.nd.array}
#' @param ctx optional The context device of the array. mx.ctx.default() will be used in default.
#'
#' @examples
#' mat = mx.nd.zeros(10)
#' as.array(mat)
#' mat2 = mx.nd.zeros(c(5,5))
#' as.array(mat)
#' mat3 = mx.nd.zeroes(c(3,3,3))
#' as.array(mat3)
#'
#' @export
mx.nd.zeros <- function(shape, ctx=NULL) {
ret <- mx.nd.internal.empty(shape, ctx)
Expand All @@ -39,7 +61,15 @@ mx.nd.zeros <- function(shape, ctx=NULL) {
#'
#' @param shape the dimension of the \code{mx.nd.array}
#' @param ctx optional The context device of the array. mx.ctx.default() will be used in default.
#'
#'
#' @examples
#' mat = mx.nd.ones(10)
#' as.array(mat)
#' mat2 = mx.nd.ones(c(5,5))
#' as.array(mat)
#' mat3 = mx.nd.ones(c(3,3,3))
#' as.array(mat3)
#'
#' @export
mx.nd.ones <- function(shape, ctx=NULL) {
ret <- mx.nd.internal.empty(shape, ctx)
Expand Down Expand Up @@ -87,6 +117,12 @@ is.MXNDArray <- function(x) {
#'
#' @return Logical indicator
#'
#' @examples
#' mat = mx.nd.array(1:10)
#' is.mx.nd.array(mat)
#' mat2 = 1:10
#' is.mx.nd.array(mat2)
#'
#' @export
is.mx.nd.array <- function(src.array) {
is.MXNDArray(src.array)
Expand Down
30 changes: 29 additions & 1 deletion R-package/R/random.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#' Set the seed used by mxnet device-specific random number generators.
#'
#' @details
#' We have a specific reason why \code{mx.set.seed} is introduced,
#' instead of simply use \code{set.seed}.
#'
Expand All @@ -12,7 +13,16 @@
#' So we introduced \code{mx.set.seed} for mxnet specific device random numbers.
#'
#' @param seed the seed value to the device random number generators.
#'
#'
#' @examples
#'
#' mx.set.seed(0)
#' as.array(mx.runif(2))
#' # 0.5488135 0.5928446
#' mx.set.seed(0)
#' as.array(mx.rnorm(2))
#' # 2.212206 1.163079
#'
#' @export
mx.set.seed <- function(seed) {
mx.internal.set.seed(seed)
Expand All @@ -25,6 +35,15 @@ mx.set.seed <- function(seed) {
#' @param max numeric, The upper bound of distribution.
#' @param ctx, optional The context device of the array. mx.ctx.default() will be used in default.
#'
#' @examples
#'
#' mx.set.seed(0)
#' as.array(mx.runif(2))
#' # 0.5488135 0.5928446
#' mx.set.seed(0)
#' as.array(mx.rnorm(2))
#' # 2.212206 1.163079
#'
#' @export
mx.runif <- function(shape, min=0, max=1, ctx=NULL) {
if (!is.numeric(min)) stop("mx.rnorm only accept numeric min")
Expand All @@ -40,6 +59,15 @@ mx.runif <- function(shape, min=0, max=1, ctx=NULL) {
#' @param sd numeric, The standard deviations.
#' @param ctx, optional The context device of the array. mx.ctx.default() will be used in default.
#'
#' @examples
#'
#' mx.set.seed(0)
#' as.array(mx.runif(2))
#' # 0.5488135 0.5928446
#' mx.set.seed(0)
#' as.array(mx.rnorm(2))
#' # 2.212206 1.163079
#'
#' @export
mx.rnorm <- function(shape, mean=0, sd=1, ctx=NULL) {
if (!is.numeric(mean)) stop("mx.rnorm only accept numeric mean")
Expand Down
38 changes: 37 additions & 1 deletion R-package/R/symbol.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,29 @@

#' Save an mx.symbol object
#'
#' @param symbol the \code{mx.symbol} object
#' @param filename the filename (including the path)
#'
#' @examples
#' data = mx.symbol.Variable('data')
#' mx.symbol.save(data, 'temp.symbol')
#' data2 = mx.symbol.load('temp.symbol')
#'
#' @export
mx.symbol.save <-function(symbol, filename) {
filename <- path.expand(filename)
symbol$save(filename)
}

#' Load an mx.symbol object
#'
#' @param filename the filename (including the path)
#'
#' @examples
#' data = mx.symbol.Variable('data')
#' mx.symbol.save(data, 'temp.symbol')
#' data2 = mx.symbol.load('temp.symbol')
#'
#' @export
mx.symbol.load <-function(filename) {
filename <- path.expand(filename)
mx.symbol.load(filename)
Expand Down Expand Up @@ -73,6 +93,22 @@ is.MXSymbol <- function(x) {
inherits(x, "Rcpp_MXSymbol")
}

#' Judge if an object is mx.symbol
#'
#' @return Logical indicator
#'
#' @examples
#' mat = mx.nd.array(1:10)
#' is.mx.nd.array(mat)
#' mat2 = 1:10
#' is.mx.nd.array(mat2)
#'
#' @export

is.mx.symbol <- function(x) {
is.MXSymbol(x)
}

arguments <- function(x) {
if (!is.MXSymbol(x))
stop("only for MXSymbol type")
Expand Down
12 changes: 12 additions & 0 deletions R-package/tests/testthat/test_ndarray.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,15 @@ test_that("element-wise calculation for matrix", {
expect_equal(x-x, as.array(mat-mat))
expect_equal(as.array(1-mat), as.array(1-mat))
})

test_that("ndarray ones, zeros, save and load", {
expect_equal(rep(0, 10), as.array(mx.nd.zeros(10)))
expect_equal(matrix(0, 10, 5), as.array(mx.nd.zeros(10, 5)))
expect_equal(rep(1, 10), as.array(mx.nd.ones(10)))
expect_equal(matrix(1, 10, 5), as.array(mx.nd.ones(10, 5)))
mat = mx.nd.array(1:20)
mx.nd.save(mat, 'temp.mat')
mat2 = mx.nd.load('temp.mat')
expect_true(is.mx.nd.array(mat2))
expect_equal(as.array(mat), as.array(mat2))
})
2 changes: 2 additions & 0 deletions R-package/tests/testthat/test_symbol.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ test_that("basic symbol operation", {
net2 = mx.symbol.FullyConnected(data=net2, name='fc4', num_hidden=20)

composed = mx.apply(net2, fc3_data=net1, name='composed')

expect_equal(arguments(composed), c('data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias', 'fc3_weight', 'fc3_bias', 'fc4_weight', 'fc4_bias'))
})


0 comments on commit a777402

Please sign in to comment.