Skip to content

Commit

Permalink
Merge pull request apache#8 from javelinjs/scala-package-cc
Browse files Browse the repository at this point in the history
NDArray functions and SGD optimizer
  • Loading branch information
terrytangyuan committed Dec 22, 2015
2 parents a6c9ead + c14ed50 commit 67abdee
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,3 @@ class FactorScheduler(protected var step: Int, protected var factor: Float) exte
this.baseLR
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ class LibInfo {
ndim: MXUintRef,
data: ArrayBuffer[Int]): Int
@native def mxNDArraySyncCopyToCPU(handle: NDArrayHandle,
data: Array[Float],
data: Array[MXFloat],
size: Int): Int
@native def mxNDArraySlice(handle: NDArrayHandle,
start: MXUint,
end: MXUint,
sliceHandle: NDArrayHandle): Int
@native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
source: Array[MXFloat],
size: Int): Int
@native def mxKVStoreCreate(name: String, handle: KVStoreHandle): Int
@native def mxKVStoreInit(handle: KVStoreHandle,
len: MXUint,
Expand Down
50 changes: 40 additions & 10 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ object NDArray {
new NDArray(handle = NDArray._newAllocHandle(shape, context, delayAlloc = false))
}

// Varargs convenience overload of empty(Array[Int], Context).
def empty(shape: Int *): NDArray = empty(shape.toArray)

// Varargs overload with an explicit context; ctx comes first so it is not
// swallowed by the shape varargs.
def empty(ctx: Context, shape: Int *): NDArray = empty(shape.toArray, ctx)

/**
* Create a new NDArray filled with 0, with specified shape.
*
Expand All @@ -211,10 +215,14 @@ object NDArray {
*/
def zeros(shape: Array[Int], ctx: Context=null): NDArray = {
val arr = empty(shape, ctx)
arr(0).set(0f)
arr.set(0f)
arr
}

// Varargs convenience overload of zeros(Array[Int], Context).
def zeros(shape: Int *): NDArray = zeros(shape.toArray)

// Varargs overload with an explicit context; ctx comes first so it is not
// swallowed by the shape varargs.
def zeros(ctx: Context, shape: Int *): NDArray = zeros(shape.toArray, ctx)

/**
* Create a new NDArray filled with 1, with specified shape.
* @param shape shape of the NDArray.
Expand All @@ -223,10 +231,25 @@ object NDArray {
*/
def ones(shape: Array[Int], ctx: Context=null): NDArray = {
val arr = empty(shape, ctx)
arr(0).set(1f)
arr.set(1f)
arr
}

// Varargs convenience overload of ones(Array[Int], Context).
def ones(shape: Int *): NDArray = ones(shape.toArray)

// Varargs overload with an explicit context; ctx comes first so it is not
// swallowed by the shape varargs.
def ones(ctx: Context, shape: Int *): NDArray = ones(shape.toArray, ctx)

/**
 * Clip ndarray elements to the range [min, max] (bounds inclusive).
 * @param array ndarray to be clipped
 * @param min lower bound; elements below it are set to min
 * @param max upper bound; elements above it are set to max
 * @return a new clipped [[NDArray]]; the input array is not modified
 */
def clip(array: NDArray, min: Float, max: Float): NDArray = {
  NDArray._genericNDArrayFunction("clip", Array(array, min, max))(0)
}

/**
* Create a new NDArray that copies content from source_array.
* @param sourceArr Source data to create NDArray from.
Expand Down Expand Up @@ -285,7 +308,10 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
* Perform a synchronized copy from the source array.
* @param source The data source we should like to copy from.
*/
def _syncCopyfrom(source: Array[Float]): Unit = ???
/**
 * Perform a synchronized copy from a Java array into this NDArray.
 * Blocks until the native copy has completed.
 * @param source data to copy; its length must equal this NDArray's size
 */
private def syncCopyfrom(source: Array[Float]): Unit = {
  // Include both sizes in the message so a mismatch is diagnosable from the log.
  require(source.length == size,
    s"array size (${source.length}) does not match the size of NDArray ($size)")
  checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length))
}

/**
* Return a sliced NDArray that shares memory with current one.
Expand All @@ -296,14 +322,14 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
*
* @return a sliced NDArray that shares memory with current one.
*/
private def _slice(start: Int, stop: Int): NDArray = {
// Return a sliced NDArray that shares memory with this one (per the native
// mxNDArraySlice call); the slice inherits this array's writability.
// NOTE(review): stop is presumably exclusive, matching the Python slice API —
// confirm against the native MXNDArraySlice contract.
def slice(start: Int, stop: Int): NDArray = {
  val sliceHandle = new NDArrayHandle()
  checkCall(_LIB.mxNDArraySlice(handle, start, stop, sliceHandle))
  new NDArray(handle = sliceHandle, writable = this.writable)
}

private def _slice(start: Int): NDArray = {
_slice(start, shape(0))
def slice(start: Int): NDArray = {
slice(start, shape(0))
}

/**
Expand All @@ -314,9 +340,6 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
*/
def waitToRead(): Unit = ???

def apply(sliceStart: Int): NDArray = _slice(sliceStart)
def apply(sliceStart: Int, sliceEnd: Int): NDArray = _slice(sliceStart, sliceEnd)

/**
* Get context of current NDArray.
* @return The context of current NDArray.
Expand All @@ -334,10 +357,17 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
this
}

def set(other: NDArray) = {
// Copy the contents of `other` into this NDArray.
// Requires this array to be writable; returns the result of copyTo.
def set(other: NDArray): NDArray = {
  require(writable, "trying to assign to a readonly NDArray")
  other.copyTo(this)
}

// Copy a Java float array into this NDArray (synchronous, via syncCopyfrom);
// the array length must match this NDArray's size. Returns this for chaining.
def set(other: Array[Float]): NDArray = {
  require(writable, "trying to assign to a readonly NDArray")
  syncCopyfrom(other)
  this
}

def +(other: NDArray): NDArray = {
NDArray._binaryNDArrayFunction("_plus", this, other)
}
Expand Down
29 changes: 26 additions & 3 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package ml.dmlc.mxnet

import scala.collection.mutable

object Optimizer {
def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
new MXKVStoreUpdater {
private val states = new scala.collection.mutable.HashMap[Int, AnyRef]
val states = new scala.collection.mutable.HashMap[Int, AnyRef]
override def update(index: Int, grad: NDArray, weight: NDArray, handle: AnyRef): Unit = {
val state = states.getOrElseUpdate(index, optimizer.createState(index, weight))
optimizer.update(index, weight, grad, state)
Expand All @@ -12,7 +14,11 @@ object Optimizer {
}
}

abstract class Optimizer extends Serializable {
abstract class Optimizer(protected var rescaleGrad: Float = 1f) extends Serializable {
protected var lrScale: mutable.Map[Int, Float] = mutable.HashMap.empty[Int, Float]
protected var numUpdate: Int = 0
protected val indexUpdateCount: mutable.Map[Int, Int] = mutable.HashMap.empty[Int, Int]

/**
* Update the parameters.
* @param index An unique integer key used to index the parameters
Expand All @@ -21,10 +27,27 @@ abstract class Optimizer extends Serializable {
* @param state NDArray or other objects returned by initState
* The auxiliary state used in optimization.
*/
def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = ???
// TODO: make state a ClassTag
def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit

// Create additional optimizer state such as momentum.
// TODO: make returned state a ClassTag
def createState(index: Int, weight: NDArray): AnyRef

// Set individual learning rate scale for parameters
/**
 * Set per-parameter learning rate scales.
 * The given immutable map is copied into a fresh mutable map, replacing any
 * previously configured scales.
 */
def setLrScale(lrScale: Map[Int, Float]) {
  val copied = mutable.HashMap.empty[Int, Float]
  copied ++= lrScale
  this.lrScale = copied
}

/**
* update num_update
* @param index The index will be updated
*/
/**
 * Record one update for the given parameter and refresh numUpdate.
 * numUpdate tracks the largest per-index update count seen so far.
 * @param index the parameter index that was just updated
 */
protected def updateCount(index: Int): Unit = {
  val newCount = indexUpdateCount.getOrElse(index, 0) + 1
  indexUpdateCount(index) = newCount
  if (newCount > numUpdate) {
    numUpdate = newCount
  }
}
}

trait MXKVStoreUpdater {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package ml.dmlc.mxnet.optimizer

import ml.dmlc.mxnet.{Optimizer, LRScheduler, NDArray}
import ml.dmlc.mxnet.NDArrayConversions._

/**
 * A very simple SGD optimizer with momentum and weight regularization.
 *
 * @param learningRate base learning rate
 * @param momentum momentum coefficient; 0 disables momentum (no state is kept)
 * @param wd L2 weight decay coefficient applied via `wd * weight`
 * @param rescaleGrad multiplier applied to the raw gradient before the update
 * @param clipGradient clip the rescaled gradient to [-clipGradient, clipGradient]; 0 disables clipping
 * @param lrScheduler optional schedule; when null the constant learningRate is used
 * @author Yizhi Liu
 */
class SGD(val learningRate: Float = 0.01f, val momentum: Float = 0.0f,
          val wd: Float = 0.0001f, rescaleGrad: Float = 1f, val clipGradient: Float = 0f,
          val lrScheduler: LRScheduler = null) extends Optimizer(rescaleGrad: Float) {
  /**
   * Update the parameters.
   * @param index An unique integer key used to index the parameters
   * @param weight weight ndarray
   * @param grad grad ndarray
   * @param state NDArray or other objects returned by initState
   *              The auxiliary state used in optimization.
   */
  override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = {
    // TODO(bing) implement wd_bias, wd_gamma, wd_beta (copy from python package)
    val lr =
      (if (lrScheduler != null) {
        val scheduledLr = lrScheduler(numUpdate)
        updateCount(index)
        scheduledLr
      } else {
        this.learningRate
      }) * lrScale.getOrElse(index, 1f)
    // NOTE(review): updateCount is only invoked on the scheduler path, so numUpdate
    // stays 0 when no scheduler is set — confirm against the Python reference impl.

    var resdGrad = grad * rescaleGrad
    if (clipGradient != 0f) {
      resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
    }
    if (state != null) {
      val mom = state.asInstanceOf[NDArray]
      mom *= momentum
      // Fix: use the rescaled/clipped gradient. The original applied the raw
      // `grad` here, silently ignoring rescaleGrad and clipGradient.
      mom += -lr * (resdGrad + wd * weight)
      weight += mom
    } else {
      require(momentum == 0f, "momentum is non-zero but no momentum state was created")
      weight += -lr * (resdGrad + wd * weight)
    }
  }

  // Create additional optimizer state such as momentum.
  // Returns null when momentum is disabled so update() takes the stateless branch.
  override def createState(index: Int, weight: NDArray): AnyRef = {
    if (momentum == 0.0f) {
      null
    } else {
      NDArray.zeros(weight.shape, weight.context)
    }
  }
}
35 changes: 26 additions & 9 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,41 @@ import ml.dmlc.mxnet.NDArrayConversions._

class NDArraySuite extends FunSuite with BeforeAndAfterAll {
test("to java array") {
val ndarray = NDArray.zeros(Array(2, 2))
val ndarray = NDArray.zeros(2, 2)
assert(ndarray.toArray === Array(0f, 0f, 0f, 0f))
}

test("to scalar") {
val ndzeros = NDArray.zeros(Array(1))
val ndzeros = NDArray.zeros(1)
assert(ndzeros.toScalar === 0f)
val ndones = NDArray.ones(Array(1))
val ndones = NDArray.ones(1)
assert(ndones.toScalar === 1f)
}

test ("call toScalar on an ndarray which is not a scalar") {
intercept[Exception] { NDArray.zeros(Array(1,1)).toScalar }
intercept[Exception] { NDArray.zeros(1, 1).toScalar }
}

test("size and shape") {
val ndzeros = NDArray.zeros(Array(4, 1))
val ndzeros = NDArray.zeros(4, 1)
assert(ndzeros.shape === Array(4, 1))
assert(ndzeros.size === 4)
}

// set(Float) should broadcast the scalar to every element of the array.
test("set scalar value") {
  val ndarray = NDArray.empty(2, 1)
  ndarray.set(10f)
  assert(ndarray.toArray === Array(10f, 10f))
}

// set(Array[Float]) copies element-for-element; round-trip through toArray
// must return the same values in the same order.
test("copy from java array") {
  val ndarray = NDArray.empty(4, 1)
  ndarray.set(Array(1f, 2f, 3f, 4f))
  assert(ndarray.toArray === Array(1f, 2f, 3f, 4f))
}

test("plus") {
val ndzeros = NDArray.zeros(Array(2, 1))
val ndzeros = NDArray.zeros(2, 1)
val ndones = ndzeros + 1f
assert(ndones.toArray === Array(1f, 1f))
assert((ndones + ndzeros).toArray === Array(1f, 1f))
Expand All @@ -38,7 +50,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
}

test("minus") {
val ndones = NDArray.ones(Array(2, 1))
val ndones = NDArray.ones(2, 1)
val ndzeros = ndones - 1f
assert(ndzeros.toArray === Array(0f, 0f))
assert((ndones - ndzeros).toArray === Array(1f, 1f))
Expand All @@ -50,7 +62,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
}

test("multiplication") {
val ndones = NDArray.ones(Array(2, 1))
val ndones = NDArray.ones(2, 1)
val ndtwos = ndones * 2
assert(ndtwos.toArray === Array(2f, 2f))
assert((ndones * ndones).toArray === Array(1f, 1f))
Expand All @@ -61,7 +73,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
}

test("division") {
val ndones = NDArray.ones(Array(2, 1))
val ndones = NDArray.ones(2, 1)
val ndzeros = ndones - 1f
val ndhalves = ndones / 2
assert(ndhalves.toArray === Array(0.5f, 0.5f))
Expand All @@ -73,4 +85,9 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
assert(ndhalves.toArray === Array(1f, 1f))
}

// clip bounds each element to [2, 5]: values below become 2, above become 5,
// and in-range values (including the bounds themselves) pass through unchanged.
test("clip") {
  val ndarray = NDArray.empty(3, 2)
  ndarray.set(Array(1f, 2f, 3f, 4f, 5f, 6f))
  assert(NDArray.clip(ndarray, 2f, 5f).toArray === Array(2f, 2f, 3f, 4f, 5f, 5f))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySlice(JNIEnv *env, jo
return ret;
}

// Copy arrSize floats from the Java float[] into the native NDArray via
// MXNDArraySyncCopyFromCPU. Returns the native API's status code (0 = success).
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
  (JNIEnv *env, jobject obj, jobject ndArrayHandle, jfloatArray sourceArr, jint arrSize) {
  jlong arrayPtr = getLongField(env, ndArrayHandle);
  jfloat *sourcePtr = env->GetFloatArrayElements(sourceArr, NULL);
  if (sourcePtr == NULL) {
    // GetFloatArrayElements failed (OutOfMemoryError is already pending in the JVM);
    // bail out with a non-zero status instead of dereferencing NULL.
    return -1;
  }
  int ret = MXNDArraySyncCopyFromCPU((NDArrayHandle)arrayPtr, (const mx_float *)sourcePtr, arrSize);
  // JNI_ABORT: the buffer is only read here, so skip the needless copy-back
  // into the Java array that mode 0 would perform.
  env->ReleaseFloatArrayElements(sourceArr, sourcePtr, JNI_ABORT);
  return ret;
}

// The related c api MXKVStoreSetUpdater function takes a c function pointer as its parameter,
// while we write java functions here in scala-package.
// Thus we have to wrap the function in a java object, and run env->CallVoidMethod(obj) once updater is invoked,
Expand Down

0 comments on commit 67abdee

Please sign in to comment.