diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
index 35e05786fe74..756194b94566 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LRScheduler.scala
@@ -51,4 +51,3 @@ class FactorScheduler(protected var step: Int, protected var factor: Float) exte
     this.baseLR
   }
 }
-
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
index 46dd88c062f6..2592394b8550 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
@@ -41,12 +41,15 @@ class LibInfo {
                                   ndim: MXUintRef,
                                   data: ArrayBuffer[Int]): Int
   @native def mxNDArraySyncCopyToCPU(handle: NDArrayHandle,
-                                     data: Array[Float],
+                                     data: Array[MXFloat],
                                      size: Int): Int
   @native def mxNDArraySlice(handle: NDArrayHandle,
                              start: MXUint,
                              end: MXUint,
                              sliceHandle: NDArrayHandle): Int
+  @native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
+                                       source: Array[MXFloat],
+                                       size: Int): Int
   @native def mxKVStoreCreate(name: String, handle: KVStoreHandle): Int
   @native def mxKVStoreInit(handle: KVStoreHandle,
                             len: MXUint,
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
index 9b8a55b99c77..36a46064e4e6 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -201,6 +201,10 @@ object NDArray {
     new NDArray(handle = NDArray._newAllocHandle(shape, context, delayAlloc = false))
   }
 
+  def empty(shape: Int *): NDArray = empty(shape.toArray)
+
+  def empty(ctx: Context, shape: Int *): NDArray = empty(shape.toArray, ctx)
+
   /**
    * Create a new NDArray filled with 0, with specified shape.
    *
@@ -211,10 +215,14 @@ object NDArray {
    */
  def zeros(shape: Array[Int], ctx: Context=null): NDArray = {
     val arr = empty(shape, ctx)
-    arr(0).set(0f)
+    arr.set(0f)
     arr
   }
 
+  def zeros(shape: Int *): NDArray = zeros(shape.toArray)
+
+  def zeros(ctx: Context, shape: Int *): NDArray = zeros(shape.toArray, ctx)
+
   /**
    * Create a new NDArray filled with 1, with specified shape.
    * @param shape shape of the NDArray.
@@ -223,10 +231,25 @@ object NDArray {
    */
   def ones(shape: Array[Int], ctx: Context=null): NDArray = {
     val arr = empty(shape, ctx)
-    arr(0).set(1f)
+    arr.set(1f)
     arr
   }
 
+  def ones(shape: Int *): NDArray = ones(shape.toArray)
+
+  def ones(ctx: Context, shape: Int *): NDArray = ones(shape.toArray, ctx)
+
+  /**
+   * Clip ndarray elements to the range [min, max]
+   * @param array ndarray to be clipped
+   * @param min lower bound of the clip range
+   * @param max upper bound of the clip range
+   * @return a new clipped [[NDArray]]
+   */
+  def clip(array: NDArray, min: Float, max: Float): NDArray = {
+    NDArray._genericNDArrayFunction("clip", Array(array, min, max))(0)
+  }
+
   /**
    * Create a new NDArray that copies content from source_array.
    * @param sourceArr Source data to create NDArray from.
@@ -285,7 +308,10 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
    * Perform a synchronized copy from the array.
    * @param source The data source we would like to copy from.
    */
-  def _syncCopyfrom(source: Array[Float]): Unit = ???
+  private def syncCopyfrom(source: Array[Float]): Unit = {
+    require(source.length == size, "array size does not match the size of NDArray")
+    checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length))
+  }
 
   /**
    * Return a sliced NDArray that shares memory with current one.
@@ -296,14 +322,14 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
    *
    * @return a sliced NDArray that shares memory with current one.
    */
-  private def _slice(start: Int, stop: Int): NDArray = {
+  def slice(start: Int, stop: Int): NDArray = {
     val sliceHandle = new NDArrayHandle()
     checkCall(_LIB.mxNDArraySlice(handle, start, stop, sliceHandle))
     new NDArray(handle = sliceHandle, writable = this.writable)
   }
 
-  private def _slice(start: Int): NDArray = {
-    _slice(start, shape(0))
+  def slice(start: Int): NDArray = {
+    slice(start, shape(0))
   }
 
   /**
@@ -314,9 +340,6 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
    */
   def waitToRead(): Unit = ???
 
-  def apply(sliceStart: Int): NDArray = _slice(sliceStart)
-  def apply(sliceStart: Int, sliceEnd: Int): NDArray = _slice(sliceStart, sliceEnd)
-
   /**
    * Get context of current NDArray.
    * @return The context of current NDArray.
@@ -334,10 +357,17 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
     this
   }
 
-  def set(other: NDArray) = {
+  def set(other: NDArray): NDArray = {
+    require(writable, "trying to assign to a readonly NDArray")
     other.copyTo(this)
   }
 
+  def set(other: Array[Float]): NDArray = {
+    require(writable, "trying to assign to a readonly NDArray")
+    syncCopyfrom(other)
+    this
+  }
+
   def +(other: NDArray): NDArray = {
     NDArray._binaryNDArrayFunction("_plus", this, other)
   }
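Taken together, the NDArray changes above add varargs constructors, in-place assignment, public slicing, and a clip operator. A minimal usage sketch of the API exactly as introduced in this diff; the Context.cpu() factory used below is an assumption, not part of this change:

    import ml.dmlc.mxnet.{Context, NDArray}

    val a = NDArray.zeros(2, 3)               // varargs shape, default context
    val b = NDArray.ones(Context.cpu(), 2, 3) // context-first overload; Context.cpu() assumed
    b.set(5f)                                 // fill every element in place
    b.set(Array(1f, 2f, 3f, 4f, 5f, 6f))      // synchronized copy from a JVM array
    val c = NDArray.clip(b, 2f, 4f)           // new NDArray with elements clipped to [2, 4]
    val row = b.slice(0, 1)                   // shares memory with b

Since the apply overloads are removed, indexing-style calls such as b(0) no longer compile; callers migrate to the public slice methods.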
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala
index 0d0cd38d6638..f9a58f5ca4db 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala
@@ -1,9 +1,11 @@
 package ml.dmlc.mxnet
 
+import scala.collection.mutable
+
 object Optimizer {
   def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
     new MXKVStoreUpdater {
-      private val states = new scala.collection.mutable.HashMap[Int, AnyRef]
+      val states = new scala.collection.mutable.HashMap[Int, AnyRef]
       override def update(index: Int, grad: NDArray, weight: NDArray, handle: AnyRef): Unit = {
         val state = states.getOrElseUpdate(index, optimizer.createState(index, weight))
         optimizer.update(index, weight, grad, state)
@@ -12,7 +14,11 @@ object Optimizer {
   }
 }
 
-abstract class Optimizer extends Serializable {
+abstract class Optimizer(protected var rescaleGrad: Float = 1f) extends Serializable {
+  protected var lrScale: mutable.Map[Int, Float] = mutable.HashMap.empty[Int, Float]
+  protected var numUpdate: Int = 0
+  protected val indexUpdateCount: mutable.Map[Int, Int] = mutable.HashMap.empty[Int, Int]
+
   /**
    * Update the parameters.
    * @param index A unique integer key used to index the parameters
@@ -21,10 +27,27 @@ abstract class Optimizer extends Serializable {
    * @param state NDArray or other objects returned by initState
    *              The auxiliary state used in optimization.
    */
-  def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = ???
+  // TODO: make state a ClassTag
+  def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit
 
   // Create additional optimizer state such as momentum.
+  // TODO: make returned state a ClassTag
   def createState(index: Int, weight: NDArray): AnyRef
+
+  // Set individual learning rate scale for parameters
+  def setLrScale(lrScale: Map[Int, Float]) {
+    this.lrScale = mutable.Map(lrScale.toSeq: _*)
+  }
+
+  /**
+   * Update num_update
+   * @param index The index to be updated
+   */
+  protected def updateCount(index: Int): Unit = {
+    val count = indexUpdateCount.getOrElseUpdate(index, 0) + 1
+    indexUpdateCount.update(index, count)
+    numUpdate = Math.max(count, numUpdate)
+  }
 }
 
 trait MXKVStoreUpdater {
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala
new file mode 100644
index 000000000000..bd434dab8ae2
--- /dev/null
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/optimizer/SGD.scala
@@ -0,0 +1,55 @@
+package ml.dmlc.mxnet.optimizer
+
+import ml.dmlc.mxnet.{Optimizer, LRScheduler, NDArray}
+import ml.dmlc.mxnet.NDArrayConversions._
+
+/**
+ * A very simple SGD optimizer with momentum and weight regularization.
+ * @author Yizhi Liu
+ */
+class SGD(val learningRate: Float = 0.01f, val momentum: Float = 0.0f,
+          val wd: Float = 0.0001f, rescaleGrad: Float = 1f, val clipGradient: Float = 0f,
+          val lrScheduler: LRScheduler = null) extends Optimizer(rescaleGrad) {
+  /**
+   * Update the parameters.
+   * @param index A unique integer key used to index the parameters
+   * @param weight weight ndarray
+   * @param grad grad ndarray
+   * @param state NDArray or other objects returned by initState
+   *              The auxiliary state used in optimization.
+   */
+  override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = {
+    // TODO(bing) implement wd_bias, wd_gamma, wd_beta (copy from python package)
+    val lr =
+      (if (lrScheduler != null) {
+        val scheduledLr = lrScheduler(numUpdate)
+        updateCount(index)
+        scheduledLr
+      } else {
+        this.learningRate
+      }) * lrScale.getOrElse(index, 1f)
+
+    var resdGrad = grad * rescaleGrad
+    if (clipGradient != 0f) {
+      resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
+    }
+    if (state != null) {
+      val mom = state.asInstanceOf[NDArray]
+      mom *= momentum
+      mom += -lr * (resdGrad + wd * weight)
+      weight += mom
+    } else {
+      require(momentum == 0f)
+      weight += -lr * (resdGrad + wd * weight)
+    }
+  }
+
+  // Create additional optimizer state such as momentum.
+  override def createState(index: Int, weight: NDArray): AnyRef = {
+    if (momentum == 0.0f) {
+      null
+    } else {
+      NDArray.zeros(weight.shape, weight.context)
+    }
+  }
+}
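To make the update path concrete, here is a minimal sketch of driving the new SGD by hand, using only names defined in this diff (values are illustrative):

    import ml.dmlc.mxnet.NDArray
    import ml.dmlc.mxnet.optimizer.SGD

    val sgd = new SGD(learningRate = 0.1f, momentum = 0.9f, clipGradient = 5f)
    sgd.setLrScale(Map(0 -> 0.5f))        // parameter 0 trains at half the base lr

    val weight = NDArray.ones(2, 1)
    val grad = NDArray.ones(2, 1)
    val mom = sgd.createState(0, weight)  // a zeros NDArray, since momentum != 0
    sgd.update(0, weight, grad, mom)      // momentum step with weight decay and clipping

Note that updateCount is only invoked on the lrScheduler branch, so numUpdate advances only when a scheduler is attached.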
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala
index 4dfa62e19621..840a381fbd87 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala
@@ -5,29 +5,41 @@ import ml.dmlc.mxnet.NDArrayConversions._
 
 class NDArraySuite extends FunSuite with BeforeAndAfterAll {
   test("to java array") {
-    val ndarray = NDArray.zeros(Array(2, 2))
+    val ndarray = NDArray.zeros(2, 2)
     assert(ndarray.toArray === Array(0f, 0f, 0f, 0f))
   }
 
   test("to scalar") {
-    val ndzeros = NDArray.zeros(Array(1))
+    val ndzeros = NDArray.zeros(1)
     assert(ndzeros.toScalar === 0f)
-    val ndones = NDArray.ones(Array(1))
+    val ndones = NDArray.ones(1)
     assert(ndones.toScalar === 1f)
   }
 
   test ("call toScalar on an ndarray which is not a scalar") {
-    intercept[Exception] { NDArray.zeros(Array(1,1)).toScalar }
+    intercept[Exception] { NDArray.zeros(1, 1).toScalar }
   }
 
   test("size and shape") {
-    val ndzeros = NDArray.zeros(Array(4, 1))
+    val ndzeros = NDArray.zeros(4, 1)
     assert(ndzeros.shape === Array(4, 1))
     assert(ndzeros.size === 4)
   }
 
+  test("set scalar value") {
+    val ndarray = NDArray.empty(2, 1)
+    ndarray.set(10f)
+    assert(ndarray.toArray === Array(10f, 10f))
+  }
+
+  test("copy from java array") {
+    val ndarray = NDArray.empty(4, 1)
+    ndarray.set(Array(1f, 2f, 3f, 4f))
+    assert(ndarray.toArray === Array(1f, 2f, 3f, 4f))
+  }
+
   test("plus") {
-    val ndzeros = NDArray.zeros(Array(2, 1))
+    val ndzeros = NDArray.zeros(2, 1)
     val ndones = ndzeros + 1f
     assert(ndones.toArray === Array(1f, 1f))
     assert((ndones + ndzeros).toArray === Array(1f, 1f))
@@ -38,7 +50,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
   }
 
   test("minus") {
-    val ndones = NDArray.ones(Array(2, 1))
+    val ndones = NDArray.ones(2, 1)
     val ndzeros = ndones - 1f
     assert(ndzeros.toArray === Array(0f, 0f))
     assert((ndones - ndzeros).toArray === Array(1f, 1f))
@@ -50,7 +62,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
 
   test("multiplication") {
-    val ndones = NDArray.ones(Array(2, 1))
+    val ndones = NDArray.ones(2, 1)
     val ndtwos = ndones * 2
     assert(ndtwos.toArray === Array(2f, 2f))
     assert((ndones * ndones).toArray === Array(1f, 1f))
@@ -61,7 +73,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
 
   test("division") {
-    val ndones = NDArray.ones(Array(2, 1))
+    val ndones = NDArray.ones(2, 1)
     val ndzeros = ndones - 1f
     val ndhalves = ndones / 2
     assert(ndhalves.toArray === Array(0.5f, 0.5f))
@@ -73,4 +85,9 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
     assert(ndhalves.toArray === Array(1f, 1f))
   }
 
+  test("clip") {
+    val ndarray = NDArray.empty(3, 2)
+    ndarray.set(Array(1f, 2f, 3f, 4f, 5f, 6f))
+    assert(NDArray.clip(ndarray, 2f, 5f).toArray === Array(2f, 2f, 3f, 4f, 5f, 5f))
+  }
 }
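The suite exercises NDArray directly; the optimizer is driven through the updater returned by Optimizer.getUpdater. A short sketch of that wiring, again restricted to the API in this diff:

    import ml.dmlc.mxnet.{NDArray, Optimizer}
    import ml.dmlc.mxnet.optimizer.SGD

    val updater = Optimizer.getUpdater(new SGD(learningRate = 0.01f, momentum = 0.9f))
    val weight = NDArray.ones(2, 2)
    val grad = NDArray.ones(2, 2)
    // the updater lazily creates and caches per-index momentum state via createState
    updater.update(0, grad, weight, null)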
diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
index 4e6b1461c945..e1bb66d45858 100644
--- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
@@ -204,6 +204,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySlice(JNIEnv *env, jo
   return ret;
 }
 
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
+  (JNIEnv *env, jobject obj, jobject ndArrayHandle, jfloatArray sourceArr, jint arrSize) {
+  jlong arrayPtr = getLongField(env, ndArrayHandle);
+  jfloat *sourcePtr = env->GetFloatArrayElements(sourceArr, NULL);
+  int ret = MXNDArraySyncCopyFromCPU((NDArrayHandle)arrayPtr, (const mx_float *)sourcePtr, arrSize);
+  env->ReleaseFloatArrayElements(sourceArr, sourcePtr, 0);
+  return ret;
+}
+
 // The related C API function MXKVStoreSetUpdater takes a C function pointer as its parameter,
 // while we write Java functions here in the scala-package.
 // Thus we have to wrap the function in a Java object, and run env->CallVoidMethod(obj) once the updater is invoked,
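For reference, the Scala-side path that ends in this new JNI entry point, with all names as defined in this diff:

    import ml.dmlc.mxnet.NDArray

    // NDArray.set(Array[Float]) -> syncCopyfrom -> _LIB.mxNDArraySyncCopyFromCPU,
    // which lands in Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyFromCPU above
    val arr = NDArray.empty(2, 2)
    arr.set(Array(1f, 2f, 3f, 4f))

On the native side, GetFloatArrayElements may pin or copy the Java array; releasing with mode 0 copies the contents back and frees the buffer. Since the native call never writes to the source, JNI_ABORT would be a slightly cheaper release mode here.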