diff --git a/R-package/R/initializer.R b/R-package/R/initializer.R index 40712432d8b6..bb81a285beaa 100644 --- a/R-package/R/initializer.R +++ b/R-package/R/initializer.R @@ -3,7 +3,7 @@ #' @param name the name of the variable. #' @param shape the shape of the array to be generated. #' -mx.init.internal.default <- function(name, shape, ctx, allow.unknown=FALSE) { +mx.init.internal.default <- function(name, shape, ctx, allow.unknown = FALSE) { if (endsWith(name, "bias")) return (mx.nd.zeros(shape)) if (endsWith(name, "gamma")) return (mx.nd.ones(shape)) if (endsWith(name, "beta")) return (mx.nd.zeros(shape)) @@ -19,7 +19,7 @@ mx.init.internal.default <- function(name, shape, ctx, allow.unknown=FALSE) { #' #' @export mx.init.uniform <- function(scale) { - function(name, shape, ctx, allow.unknown=FALSE) { + function(name, shape, ctx, allow.unknown = FALSE) { if (!endsWith(name, "weight")) { return (mx.init.internal.default(name = name, shape = shape, allow.unknown = allow.unknown)) } @@ -33,7 +33,7 @@ mx.init.uniform <- function(scale) { #' #' @export mx.init.normal <- function(sd) { - function(name, shape, ctx, allow.unknown=FALSE) { + function(name, shape, ctx, allow.unknown = FALSE) { if (!endsWith(name, "weight")) { return (mx.init.internal.default(name = name, shape = shape, allow.unknown = allow.unknown)) } @@ -59,15 +59,15 @@ mx.init.Xavier <- function(rnd_type = "uniform", factor_type = "avg", return (mx.init.internal.default(name = name, shape = shape, allow.unknown = allow.unknown)) } - fan_out = shape[length(shape)] - fan_in = prod(shape[-length(shape)]) + fan_out <- shape[length(shape)] + fan_in <- prod(shape[-length(shape)]) factor_val <- switch(factor_type, "avg" = (fan_in + fan_out) / 2, "in" = fan_in, "out" = fan_out, stop("Not supported factor type. See usage of function mx.init.Xavier")) - scale = sqrt(magnitude / factor_val) + scale <- sqrt(magnitude / factor_val) if (rnd_type == "uniform"){ return(mx.nd.random.uniform(low = -scale, high = scale, shape = shape)) @@ -83,14 +83,16 @@ mx.init.Xavier <- function(rnd_type = "uniform", factor_type = "avg", #' Create initialization of argument like arg.array #' #' @param initializer The initializer. -#' @param shape.array named-list The shape of the weights +#' @param shape.array A named list that represents the shape of the weights #' @param ctx mx.context The context of the weights #' @param skip.unknown Whether skip the unknown weight types #' @export -mx.init.create <- function(initializer, shape.array, ctx=NULL, skip.unknown=TRUE) { +mx.init.create <- function(initializer, shape.array, ctx = NULL, skip.unknown = TRUE) { if (length(shape.array) == 0) return(list()) - names = names(shape.array) - ret <- lapply(seq_along(names), function(i) initializer(names[[i]], shape.array[[i]], ctx, allow.unknown=skip.unknown)) + names <- names(shape.array) + ret <- lapply( + seq_along(names), + function(i) initializer(names[[i]], shape.array[[i]], ctx, allow.unknown = skip.unknown)) names(ret) <- names if (skip.unknown) { ret <- mx.util.filter.null(ret)