Skip to content

Commit

Permalink
Edit shape.array doc and some style improvements (apache#12162)
Browse files Browse the repository at this point in the history
* Edit shape.array doc and some style improvements

* Trigger CI

* Trigger CI
  • Loading branch information
terrytangyuan authored and anirudh2290 committed Sep 19, 2018
1 parent c417bbb commit f6aa533
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions R-package/R/initializer.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' @param name the name of the variable.
#' @param shape the shape of the array to be generated.
#'
mx.init.internal.default <- function(name, shape, ctx, allow.unknown=FALSE) {
mx.init.internal.default <- function(name, shape, ctx, allow.unknown = FALSE) {
if (endsWith(name, "bias")) return (mx.nd.zeros(shape))
if (endsWith(name, "gamma")) return (mx.nd.ones(shape))
if (endsWith(name, "beta")) return (mx.nd.zeros(shape))
Expand All @@ -19,7 +19,7 @@ mx.init.internal.default <- function(name, shape, ctx, allow.unknown=FALSE) {
#'
#' @export
mx.init.uniform <- function(scale) {
function(name, shape, ctx, allow.unknown=FALSE) {
function(name, shape, ctx, allow.unknown = FALSE) {
if (!endsWith(name, "weight")) {
return (mx.init.internal.default(name = name, shape = shape, allow.unknown = allow.unknown))
}
Expand All @@ -33,7 +33,7 @@ mx.init.uniform <- function(scale) {
#'
#' @export
mx.init.normal <- function(sd) {
function(name, shape, ctx, allow.unknown=FALSE) {
function(name, shape, ctx, allow.unknown = FALSE) {
if (!endsWith(name, "weight")) {
return (mx.init.internal.default(name = name, shape = shape, allow.unknown = allow.unknown))
}
Expand All @@ -59,15 +59,15 @@ mx.init.Xavier <- function(rnd_type = "uniform", factor_type = "avg",
return (mx.init.internal.default(name = name, shape = shape, allow.unknown = allow.unknown))
}

fan_out = shape[length(shape)]
fan_in = prod(shape[-length(shape)])
fan_out <- shape[length(shape)]
fan_in <- prod(shape[-length(shape)])
factor_val <- switch(factor_type,
"avg" = (fan_in + fan_out) / 2,
"in" = fan_in,
"out" = fan_out,
stop("Not supported factor type. See usage of function mx.init.Xavier"))

scale = sqrt(magnitude / factor_val)
scale <- sqrt(magnitude / factor_val)

if (rnd_type == "uniform"){
return(mx.nd.random.uniform(low = -scale, high = scale, shape = shape))
Expand All @@ -83,14 +83,16 @@ mx.init.Xavier <- function(rnd_type = "uniform", factor_type = "avg",
#' Create initialization of argument like arg.array
#'
#' @param initializer The initializer.
#' @param shape.array named-list The shape of the weights
#' @param shape.array A named list that represents the shape of the weights
#' @param ctx mx.context The context of the weights
#' @param skip.unknown Whether skip the unknown weight types
#' @export
mx.init.create <- function(initializer, shape.array, ctx=NULL, skip.unknown=TRUE) {
mx.init.create <- function(initializer, shape.array, ctx = NULL, skip.unknown = TRUE) {
if (length(shape.array) == 0) return(list())
names = names(shape.array)
ret <- lapply(seq_along(names), function(i) initializer(names[[i]], shape.array[[i]], ctx, allow.unknown=skip.unknown))
names <- names(shape.array)
ret <- lapply(
seq_along(names),
function(i) initializer(names[[i]], shape.array[[i]], ctx, allow.unknown = skip.unknown))
names(ret) <- names
if (skip.unknown) {
ret <- mx.util.filter.null(ret)
Expand Down

0 comments on commit f6aa533

Please sign in to comment.