// --- scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala (new file) ---
package ml.dmlc.mxnet

import ml.dmlc.mxnet.NDArray.{array, zeros, ones}

/**
 * Base class for weight initializers.
 *
 * An initializer is applied per parameter: `apply` inspects the parameter
 * name and dispatches to the matching initialization routine. Concrete
 * subclasses only need to provide `_initWeight`.
 *
 * @author Yuan Tang
 */
abstract class Initializer {

  /**
   * Initialize a single ndarray according to its parameter name.
   *
   * @param name name of the corresponding ndarray (drives the dispatch below)
   * @param arr  ndarray to be initialized in place
   * @throws IllegalArgumentException if the name matches no known pattern
   */
  def apply(name: String, arr: NDArray): Unit = {
    if (name.startsWith("upsampling")) {
      _initBilinear(name, arr)
    } else if (name.endsWith("bias")) {
      _initBias(name, arr)
    } else if (name.endsWith("gamma")) {
      _initGamma(name, arr)
    } else if (name.endsWith("beta")) {
      _initBeta(name, arr)
    } else if (name.endsWith("weight")) {
      _initWeight(name, arr)
    } else if (name.endsWith("moving_mean")
            || name.endsWith("moving_var")
            || name.endsWith("moving_avg")) {
      // running statistics (batch-norm style) all start at zero
      _initZero(name, arr)
    } else {
      throw new IllegalArgumentException(s"Unknown initialization pattern for $name.")
    }
  }

  /**
   * Fill `arr` with a bilinear-upsampling kernel.
   *
   * Assumes a 4-d shape where shape(2)/shape(3) are the spatial dims
   * (presumably NCHW — confirm with upsampling callers).
   */
  def _initBilinear(name: String, arr: NDArray): Unit = {
    val weight = Array.fill[Float](arr.size)(0.0f)
    val shape = arr.shape
    val f = shape(3) / 2.0f
    val c = (2 * f - 1 - f % 2) / (2.0f * f)

    // BUGFIX: was `0 to arr.size`, which is inclusive and runs one past the
    // end of `weight`, throwing ArrayIndexOutOfBoundsException on the last step.
    (0 until arr.size).foreach { i =>
      val x = i % shape(3)
      val y = (i / shape(3)) % shape(2)
      weight(i) = (1 - math.abs(x / f - c)) * (1 - math.abs(y / f - c))
    }

    arr.set(array(weight))
  }

  /** Fill with zeros (used for moving statistics). */
  def _initZero(name: String, arr: NDArray): Unit = {
    arr.set(0f)
  }

  /** Biases start at zero. */
  def _initBias(name: String, arr: NDArray): Unit = {
    arr.set(0f)
  }

  /** Batch-norm scale (gamma) starts at one. */
  def _initGamma(name: String, arr: NDArray): Unit = {
    arr.set(1f)
  }

  /** Batch-norm shift (beta) starts at zero. */
  def _initBeta(name: String, arr: NDArray): Unit = {
    arr.set(0f)
  }

  /** Strategy hook: how to initialize a weight ndarray. */
  def _initWeight(name: String, arr: NDArray): Unit
}

/**
 * Initialize the weight with uniform [-scale, scale].
 *
 * @param scale The half-width of the uniform distribution.
 */
class Uniform(protected val scale: Float = 0.07f) extends Initializer {
  override def _initWeight(name: String, arr: NDArray): Unit = {
    Random.uniform(-scale, scale, out = arr)
  }
}

/**
 * Initialize the weight with normal(0, sigma).
 *
 * @param sigma Standard deviation of the gaussian distribution.
 */
class Normal(protected val sigma: Float = 0.01f) extends Initializer {
  override def _initWeight(name: String, arr: NDArray): Unit = {
    Random.normal(0, sigma, out = arr)
  }
}

/**
 * Initialize the weight with Xavier or a similar fan-based scheme.
 *
 * @param rndType    Options are: "uniform" or "normal"
 * @param factorType Options are: "avg", "in", "out"
 * @param magnitude  Scale of the random number range.
 */
class Xavier(protected val rndType: String = "uniform",
             protected val factorType: String = "avg",
             protected val magnitude: Int = 3) extends Initializer {

  override def _initWeight(name: String, arr: NDArray): Unit = {
    val shape = arr.shape
    val fanIn = shape.slice(1, shape.length).product
    val fanOut = shape(0)

    // BUGFIX: factor was an Int, so both `(fanIn + fanOut) / 2` and
    // `magnitude / factor` used integer division. For any factor > magnitude
    // that made scale == 0, i.e. every weight was initialized to zero.
    val factor: Float = factorType match {
      case "avg" => (fanIn + fanOut) / 2.0f
      case "in"  => fanIn.toFloat
      case "out" => fanOut.toFloat
      case _     => throw new IllegalArgumentException("Incorrect factor type")
    }
    val scale = math.sqrt(magnitude / factor).toFloat

    rndType match {
      case "uniform" => Random.uniform(-scale, scale, out = arr)
      case "normal"  => Random.normal(0, scale, out = arr)
      case _         => throw new IllegalArgumentException("Unknown random type")
    }
  }
}
// ---------------------------------------------------------------------------
// (diff continues: scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala)
// --- scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala (additions) ---
// Placeholders for the native sampling calls; to be wired to the C API.
// TODO: implement via the imperative NDArray function invocation.
def _randomUniform(low: Float, high: Float, out: NDArray) = ???

def _randomGaussian(mean: Float, stdvar: Float, out: NDArray) = ???

/**
 * Create a new NDArray that copies content from source_array.
 * Takes Array[Float] (NDArray payloads are float32), replacing the old
 * Array[Int] signature.
 *
 * @param sourceArr Source data to create NDArray from.
 * @param ctx The context of the NDArray, default to current default context.
 * @return The created NDArray.
 */
def array(sourceArr: Array[Float], ctx: Context = null): NDArray = ???

// --- scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala (new file) ---
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._
import ml.dmlc.mxnet.NDArray.{_randomUniform, _randomGaussian, empty}

/**
 * Random number interface of mxnet.
 * @author Yuan Tang
 */
object Random {

  /**
   * Generate a uniform distribution in [low, high) with the given shape.
   *
   * @param low   The lower bound of the distribution.
   * @param high  The upper bound of the distribution.
   * @param shape Output shape of the NDArray generated; required when out is null.
   * @param ctx   Context of the output NDArray; default context if not specified.
   * @param out   Output placeholder; mutually exclusive with shape/ctx.
   * @return The NDArray holding the generated result.
   */
  def uniform(low: Float, high: Float,
              shape: Array[Int] = null,
              ctx: Context = null,
              out: NDArray = null): NDArray = {
    val target =
      if (out != null) {
        // BUGFIX: use short-circuit && (original used bitwise-style &) and
        // fix the grammar of the failure message.
        require(shape == null && ctx == null, "shape and ctx are not needed when out is specified.")
        out
      } else {
        require(shape != null, "shape is required when out is not specified")
        empty(shape, ctx)
      }
    // last expression is the result — no explicit `return`
    _randomUniform(low, high, target)
  }

  /**
   * Generate a normal (Gaussian) distribution N(mean, stdvar^^2) with the given shape.
   *
   * @param mean   The mean of the normal distribution.
   * @param stdvar The standard deviation of the normal distribution.
   * @param shape  Output shape of the NDArray generated; required when out is null.
   * @param ctx    Context of the output NDArray; default context if not specified.
   * @param out    Output placeholder; mutually exclusive with shape/ctx.
   * @return The NDArray holding the generated result.
   */
  def normal(mean: Float, stdvar: Float,
             shape: Array[Int] = null,
             ctx: Context = null,
             out: NDArray = null): NDArray = {
    val target =
      if (out != null) {
        require(shape == null && ctx == null, "shape and ctx are not needed when out is specified.")
        out
      } else {
        require(shape != null, "shape is required when out is not specified")
        empty(shape, ctx)
      }
    _randomGaussian(mean, stdvar, target)
  }

  /**
   * Seed the random number generators in mxnet.
   *
   * This seed affects the behavior of functions in this module, as well as
   * results from executors that contain random numbers (e.g. Dropout operators).
   *
   * @param seedState The random number seed to set to all devices.
   * @note The random number generator of mxnet is by default device specific:
   *       with the same seed, the sequence generated on GPU0 can differ from CPU.
   */
  def seed(seedState: Int): Unit = {
    // TODO: call into the native library once the JNI binding exists.
    // checkCall(_LIB.mxRandomSeed(seedState))
  }
}