Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Edit shape.array doc and some style improvements #12162

Merged
merged 3 commits into from
Aug 29, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions R-package/R/initializer.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' @param name the name of the variable.
#' @param shape the shape of the array to be generated.
#'
mx.init.internal.default <- function(name, shape, ctx, allow.unknown=FALSE) {
mx.init.internal.default <- function(name, shape, ctx, allow.unknown = FALSE) {
if (endsWith(name, "bias")) return (mx.nd.zeros(shape))
if (endsWith(name, "gamma")) return (mx.nd.ones(shape))
if (endsWith(name, "beta")) return (mx.nd.zeros(shape))
Expand All @@ -19,7 +19,7 @@ mx.init.internal.default <- function(name, shape, ctx, allow.unknown=FALSE) {
#'
#' @export
mx.init.uniform <- function(scale) {
function(name, shape, ctx, allow.unknown=FALSE) {
function(name, shape, ctx, allow.unknown = FALSE) {
if (!endsWith(name, "weight")) {
return (mx.init.internal.default(name = name, shape = shape, allow.unknown = allow.unknown))
}
Expand All @@ -33,7 +33,7 @@ mx.init.uniform <- function(scale) {
#'
#' @export
mx.init.normal <- function(sd) {
function(name, shape, ctx, allow.unknown=FALSE) {
function(name, shape, ctx, allow.unknown = FALSE) {
if (!endsWith(name, "weight")) {
return (mx.init.internal.default(name = name, shape = shape, allow.unknown = allow.unknown))
}
Expand All @@ -59,15 +59,15 @@ mx.init.Xavier <- function(rnd_type = "uniform", factor_type = "avg",
return (mx.init.internal.default(name = name, shape = shape, allow.unknown = allow.unknown))
}

fan_out = shape[length(shape)]
fan_in = prod(shape[-length(shape)])
fan_out <- shape[length(shape)]
fan_in <- prod(shape[-length(shape)])
factor_val <- switch(factor_type,
"avg" = (fan_in + fan_out) / 2,
"in" = fan_in,
"out" = fan_out,
stop("Not supported factor type. See usage of function mx.init.Xavier"))

scale = sqrt(magnitude / factor_val)
scale <- sqrt(magnitude / factor_val)

if (rnd_type == "uniform"){
return(mx.nd.random.uniform(low = -scale, high = scale, shape = shape))
Expand All @@ -83,14 +83,16 @@ mx.init.Xavier <- function(rnd_type = "uniform", factor_type = "avg",
#' Create initialization of argument like arg.array
#'
#' @param initializer The initializer.
#' @param shape.array named-list The shape of the weights
#' @param shape.array A named list that represents the shape of the weights
#' @param ctx mx.context The context of the weights
#' @param skip.unknown Whether skip the unknown weight types
#' @export
mx.init.create <- function(initializer, shape.array, ctx=NULL, skip.unknown=TRUE) {
mx.init.create <- function(initializer, shape.array, ctx = NULL, skip.unknown = TRUE) {
if (length(shape.array) == 0) return(list())
names = names(shape.array)
ret <- lapply(seq_along(names), function(i) initializer(names[[i]], shape.array[[i]], ctx, allow.unknown=skip.unknown))
names <- names(shape.array)
ret <- lapply(
seq_along(names),
function(i) initializer(names[[i]], shape.array[[i]], ctx, allow.unknown = skip.unknown))
names(ret) <- names
if (skip.unknown) {
ret <- mx.util.filter.null(ret)
Expand Down