Commit

Merge pull request apache#28 from vchuravy/vc/fix_xavier
Different Xaiver variants
pluskid committed Nov 17, 2015
2 parents aea096f + 4f09440 commit 56bd4f7
Showing 1 changed file with 49 additions and 3 deletions.
52 changes: 49 additions & 3 deletions src/initializer.jl
@@ -104,14 +104,60 @@ end
The initializer documented in the paper [Bengio and Glorot 2010]: *Understanding
the difficulty of training deep feedforward neural networks*.
There are several different versions of the XaiverInitializer used in the wild.
The general idea is that the variance of the initialization distribution is controlled
by the dimensionality of the input and output. As a distribution one can choose either
a normal distribution with mean μ = 0 and variance σ², or a uniform distribution from -σ to σ.
Several different ways of calculating the variance are given in the literature or are
used by various libraries.
- original [Bengio and Glorot 2010]: σ² = 2 / (in + out)
- msra [K. He, X. Zhang, S. Ren, and J. Sun 2015]: σ² = 2 / in
- caffe_avg: σ² = 6 / (in + out)
- caffe_in: σ² = 3 / in
- caffe_out: σ² = 3 / out
- mxnet: σ² = 3 / (in + out)
Distribution and variant can be chosen by enums (prefixed by xv_).
As an example take mx.XaiverInitializer(distribution = mx.xv_uniform, variant = mx.xv_mxnet),
which is currently the default.
=#

@enum XaiverDistribution xv_uniform xv_normal
@enum XaiverVariant xv_original xv_mrsa xv_caffe_avg xv_caffe_in xv_caffe_out xv_mxnet

immutable XaiverInitializer <: AbstractInitializer
  distribution :: XaiverDistribution
  variant :: XaiverVariant
end
XaiverInitializer(; distribution = xv_uniform, variant = xv_mxnet) = XaiverInitializer(distribution, variant)

-function _init_weight(self :: NormalInitializer, name :: Base.Symbol, array :: NDArray)
+function _init_weight(self :: XaiverInitializer, name :: Base.Symbol, array :: NDArray)
  dims = size(array)
  fan_in = prod(dims[2:end])
  fan_out = dims[1]
-  scale = sqrt(3 / (fan_in + fan_out))
-  rand!(-scale, scale, array)

  if self.distribution == xv_uniform
    func(σ, data) = rand!(-σ, σ, data)
  elseif self.distribution == xv_normal
    func(σ, data) = randn!(0.0, σ, data)
  end

  if self.variant == xv_caffe_avg
    var = 6 / (fan_in + fan_out)
  elseif self.variant == xv_caffe_in
    var = 3 / fan_in
  elseif self.variant == xv_caffe_out
    var = 3 / fan_out
  elseif self.variant == xv_mrsa
    var = 2 / fan_in
  elseif self.variant == xv_original
    var = 2 / (fan_in + fan_out)
  elseif self.variant == xv_mxnet
    var = 3 / (fan_in + fan_out)
  end

  func(sqrt(var), array)  # func takes σ, so pass the square root of the variance
end
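
To compare the variants listed in the doc comment, the following standalone sketch evaluates each formula for one weight shape. It assumes plain Julia 1.x without MXNet. `xavier_value` and `fans` are hypothetical helper names that do not exist in the library, the variants are keyed by plain symbols rather than the `xv_` enums, and the 64x3x3x3 shape is made up for illustration. Since the removed line used `sqrt(3 / (fan_in + fan_out))` as the uniform bound, the square root of each value is printed as well.

```julia
# Standalone sketch (plain Julia 1.x, no MXNet dependency).
# `xavier_value` and `fans` are hypothetical helpers, not part of MXNet.jl.

# Value that each variant assigns, following the list in the doc comment.
function xavier_value(variant::Symbol, fan_in::Integer, fan_out::Integer)
    if variant == :original
        return 2 / (fan_in + fan_out)
    elseif variant == :msra
        return 2 / fan_in
    elseif variant == :caffe_avg
        return 6 / (fan_in + fan_out)
    elseif variant == :caffe_in
        return 3 / fan_in
    elseif variant == :caffe_out
        return 3 / fan_out
    elseif variant == :mxnet
        return 3 / (fan_in + fan_out)
    else
        throw(ArgumentError("unknown variant: $variant"))
    end
end

# Shape convention used in the diff: the first dimension is fan_out,
# the product of the remaining dimensions is fan_in.
fans(dims::Tuple) = (prod(dims[2:end]), dims[1])

# Hypothetical weight shape, chosen only for illustration.
fan_in, fan_out = fans((64, 3, 3, 3))

for v in (:original, :msra, :caffe_avg, :caffe_in, :caffe_out, :mxnet)
    value = xavier_value(v, fan_in, fan_out)
    println(rpad(string(v), 10), " value = ", value, "  sqrt = ", sqrt(value))
end
```

With the convention above, the example shape gives fan_in = 27 and fan_out = 64. The caffe_avg value is always three times the original one and twice the mxnet one, which makes the spread between the variants easy to see for any shape.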
