Commit

Merge pull request apache#13 from terrytangyuan/terry
Random and Initializer
yzhliu committed Dec 29, 2015
2 parents 1d988d5 + acb82e2 commit aea7c13
Showing 3 changed files with 211 additions and 1 deletion.
133 changes: 133 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala
@@ -0,0 +1,133 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.NDArray.{array, zeros, ones}


/**
*
* Base class for Initializer.
*
* @author Yuan Tang
*/
abstract class Initializer {

/**
* Initialize an NDArray by dispatching on its name pattern.
*
* @param name name of the corresponding ndarray
* @param arr ndarray to be initialized
*/
def apply(name: String, arr: NDArray): Unit = {

if (name.startsWith("upsampling")) {
_initBilinear(name, arr)
} else if (name.endsWith("bias")) {
_initBias(name, arr)
} else if (name.endsWith("gamma")) {
_initGamma(name, arr)
} else if (name.endsWith("beta")) {
_initBeta(name, arr)
} else if (name.endsWith("weight")) {
_initWeight(name, arr)
} else if (name.endsWith("moving_mean")) {
_initZero(name, arr)
} else if (name.endsWith("moving_var")) {
_initZero(name, arr)
} else if (name.endsWith("moving_avg")) {
_initZero(name, arr)
} else {
throw new IllegalArgumentException(s"Unknown initialization pattern for ${name}.")
}
}

def _initBilinear(name: String, arr: NDArray): Unit = {
val weight = Array.fill[Float](arr.size)(0.0f)
val shape = arr.shape
val f = shape(3) / 2.0f
val c = (2 * f - 1 - f % 2) / (2.0f * f)

// Fill the kernel with the separable bilinear filter w(x) * w(y),
// where w(t) = 1 - |t / f - c|. (`until`, not `to`: `0 to arr.size`
// would index one element past the end of the weight array.)
(0 until arr.size).foreach { i =>
val x = i % shape(3)
val y = (i / shape(3)) % shape(2)
weight(i) = (1 - math.abs(x / f - c)) * (1 - math.abs(y / f - c))
}

arr.set(array(weight))
}
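
// Worked example for _initBilinear above (illustrative): with a 4x4 kernel,
// shape(3) == 4 gives f = 2.0 and c = 0.75, so the 1-D profile is
// (0.25, 0.75, 0.75, 0.25) and each 2-D entry w(x) * w(y) is the usual
// bilinear upsampling filter.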

def _initZero(name: String, arr: NDArray): Unit = {
arr.set(0f)
}

def _initBias(name: String, arr: NDArray): Unit = {
arr.set(0f)
}

def _initGamma(name: String, arr: NDArray): Unit = {
arr.set(1f)
}

def _initBeta(name: String, arr: NDArray): Unit = {
arr.set(0f)
}

def _initWeight(name: String, arr: NDArray): Unit
}
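
// Illustrative sketch, not part of this commit: a concrete Initializer only
// needs to supply _initWeight; apply() dispatches the bias/gamma/beta and
// moving_* cases by name. The `Constant` class and its value are hypothetical.
class Constant(protected val value: Float = 0.5f) extends Initializer {
override def _initWeight(name: String, arr: NDArray): Unit = {
arr.set(value)
}
}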


/**
* Initialize the weight with a uniform distribution over [-scale, scale].
*
* @param scale The scale of the uniform distribution.
*/
class Uniform(protected val scale: Float = 0.07f) extends Initializer {
override def _initWeight(name: String, arr: NDArray): Unit = {
Random.uniform(-scale, scale, out = arr)
}
}


/**
* Initialize the weight with a normal distribution N(0, sigma^2).
*
* @param sigma Standard deviation of the Gaussian distribution.
*/
class Normal(protected val sigma: Float = 0.01f) extends Initializer {
override def _initWeight(name: String, arr: NDArray): Unit = {
Random.normal(0, sigma, out = arr)
}
}
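
// Usage sketch (illustrative; `w` and `b` are NDArrays assumed created
// elsewhere):
//   val init = new Uniform(scale = 0.1f)
//   init("fc1_weight", w)  // ends with "weight", filled from [-0.1, 0.1)
//   init("fc1_bias", b)    // ends with "bias", zeroed
//   new Normal(sigma = 0.02f).apply("fc2_weight", w)  // N(0, 0.02^2)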


/**
* Initialize the weight with the Xavier or a similar initialization scheme.
*
* @param rndType Options are: "uniform" or "normal"
* @param factorType Options are: "avg", "in", "out"
* @param magnitude Scale of the random number range.
*/
class Xavier(protected val rndType: String = "uniform",
protected val factorType: String = "avg",
protected val magnitude: Int = 3) extends Initializer {

override def _initWeight(name: String, arr: NDArray): Unit = {
val shape = arr.shape
val fanIn = shape.slice(1, shape.length).product
val fanOut = shape(0)
// Compute the factor in floating point so that neither (fanIn + fanOut) / 2
// nor magnitude / factor truncates under integer division.
val factor: Float = factorType match {
case "avg" => (fanIn + fanOut) / 2.0f
case "in" => fanIn.toFloat
case "out" => fanOut.toFloat
case _ => throw new IllegalArgumentException("Incorrect factor type")
}
val scale = math.sqrt(magnitude / factor).toFloat

rndType match {
case "uniform" => Random.uniform(-scale, scale, out=arr)
case "normal" => Random.normal(0, scale, out=arr)
case _ => throw new IllegalArgumentException("Unknown random type")
}
}
}
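
// Worked example (illustrative): for a dense weight of shape (128, 256),
// fanIn = 256 and fanOut = 128. With the defaults ("uniform", "avg",
// magnitude = 3): factor = (256 + 128) / 2.0f = 192 and
// scale = sqrt(3 / 192) = 0.125, so the weight is drawn from
// uniform [-0.125, 0.125).
//   new Xavier().apply("fc1_weight", w)  // w: a 128x256 NDArray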
8 changes: 7 additions & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -268,13 +268,19 @@ object NDArray {
NDArray._unaryNDArrayFunction("norm", src)
}

// TODO
def _randomUniform(low: Float, high: Float, out: NDArray): NDArray = ???

def _randomGaussian(mean: Float, stdvar: Float, out: NDArray): NDArray = ???


/**
* Create a new NDArray that copies content from source_array.
* @param sourceArr Source data to create NDArray from.
* @param ctx The context of the NDArray, default to current default context.
* @return The created NDArray.
*/
def array(sourceArr: Array[Int], ctx: Context = null): NDArray = ???
def array(sourceArr: Array[Float], ctx: Context = null): NDArray = ???

/**
* Load ndarray from binary file.
71 changes: 71 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
@@ -0,0 +1,71 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._
import ml.dmlc.mxnet.NDArray.{_randomUniform, _randomGaussian, empty}

/**
* Random number interface of mxnet.
* @author Yuan Tang
*/
object Random {
/**
* Generate a uniform distribution over [low, high) with the given shape.
*
* @param low The lower bound of the distribution.
* @param high The upper bound of the distribution.
* @param shape Output shape of the NDArray generated.
* @param ctx Context of the output NDArray, will use default context if not specified.
* @param out Output placeholder.
* @return The result NDArray with generated result.
*/
def uniform(low: Float, high: Float, shape: Array[Int] = null, ctx: Context = null, out: NDArray = null): NDArray = {
var outCopy = out
if (outCopy != null) {
require(shape == null && ctx == null, "shape and ctx are not needed when out is specified.")
} else {
require(shape != null, "shape is required when out is not specified.")
outCopy = empty(shape, ctx)
}
_randomUniform(low, high, outCopy)
}
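
// Usage sketch for uniform above (illustrative): draw into a fresh 2x2
// array, or fill an existing one in place; exactly one of `shape` and
// `out` is supplied.
//   val a = Random.uniform(-1f, 1f, shape = Array(2, 2))
//   Random.uniform(-1f, 1f, out = a)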


/**
* Generate a normal (Gaussian) distribution N(mean, stdvar^2) with the given shape.
*
* @param mean The mean of the normal distribution.
* @param stdvar The standard deviation of the normal distribution.
* @param shape Output shape of the NDArray generated.
* @param ctx Context of the output NDArray, will use default context if not specified.
* @param out Output placeholder.
* @return The result NDArray with generated result.
*/
def normal(mean: Float, stdvar: Float, shape: Array[Int] = null, ctx: Context = null, out: NDArray = null): NDArray = {
var outCopy = out
if (outCopy != null) {
require(shape == null && ctx == null, "shape and ctx are not needed when out is specified.")
} else {
require(shape != null, "shape is required when out is not specified.")
outCopy = empty(shape, ctx)
}
_randomGaussian(mean, stdvar, outCopy)
}


/**
* Seed the random number generators in mxnet.
*
* This seed affects the behavior of functions in this module, as well as
* results from executors that contain random operators, such as Dropout.
*
* @param seedState The random number seed to set on all devices.
* @note The random number generator of mxnet is by default device specific.
* This means that even with the same seed, the sequence generated on
* GPU 0 can differ from the one generated on CPU.
*/
def seed(seedState: Int): Unit = {
// TODO
// checkCall(_LIB.mxRandomSeed(seedState))
}
}
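
// End-to-end sketch (illustrative; seed() is still a stub in this commit,
// pending the native mxRandomSeed binding):
//   Random.seed(42)
//   val u = Random.uniform(0f, 1f, shape = Array(3, 4))
//   val g = Random.normal(0f, 0.01f, shape = Array(3, 4))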
