Skip to content

Commit

Permalink
Merge pull request apache#19 from javelinjs/scala-package-ndarray-fix
Browse files Browse the repository at this point in the history
NDArray handle ptr change to pure long
  • Loading branch information
yanqingmen committed Jan 6, 2016
2 parents 1ab9104 + 9430f20 commit 722c7a8
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 117 deletions.
6 changes: 4 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ object Base {
type CPtrAddress = Long

type SymbolHandle = CPtrAddress
type NDArrayHandle = CPtrAddress
type FunctionHandle = CPtrAddress

type MXUintRef = RefInt
type MXFloatRef = RefFloat
type NDArrayHandle = RefLong
type FunctionHandle = RefLong
type NDArrayHandleRef = RefLong
type FunctionHandleRef = RefLong
type DataIterHandle = RefLong
type DataIterCreator = RefLong
type KVStoreHandle = RefLong
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) {
*/
def backward(outGrads: Array[NDArray]): Unit = {
require(outGrads != null)
val ndArrayPtrs = outGrads.map(_.handle.value)
val ndArrayPtrs = outGrads.map(_.handle)
checkCall(_LIB.mxExecutorBackward(handle, outGrads.length, ndArrayPtrs))
}

Expand Down
8 changes: 4 additions & 4 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,19 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter {
* @return the data of current batch
*/
override def getData(): NDArray = {
val out = new NDArrayHandle
val out = new NDArrayHandleRef
checkCall(_LIB.mxDataIterGetData(handle, out))
new NDArray(out, writable = false)
new NDArray(out.value, writable = false)
}

/**
* Get label of current batch
* @return the label of current batch
*/
override def getLabel(): NDArray = {
val out = new NDArrayHandle
val out = new NDArrayHandleRef
checkCall(_LIB.mxDataIterGetLabel(handle, out))
new NDArray(out, writable = false)
new NDArray(out.value, writable = false)
}

/**
Expand Down
6 changes: 3 additions & 3 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class KVStore(private val handle: KVStoreHandle) {
*/
def init(keys: Array[Int], values: Array[NDArray]): Unit = {
require(keys.length == values.length, "len(keys) != len(values)")
val valuePtrs = values.map(_.handle.value)
val valuePtrs = values.map(_.handle)
checkCall(_LIB.mxKVStoreInit(handle, keys.length, keys, valuePtrs))
}

Expand All @@ -61,7 +61,7 @@ class KVStore(private val handle: KVStoreHandle) {
*/
def push(keys: Array[Int], values: Array[NDArray], priority: Int): Unit = {
require(keys.length == values.length, "len(keys) != len(values)")
val valuePtrs = values.map(_.handle.value)
val valuePtrs = values.map(_.handle)
checkCall(_LIB.mxKVStorePush(handle, keys.length, keys, valuePtrs, priority))
}

Expand Down Expand Up @@ -97,7 +97,7 @@ class KVStore(private val handle: KVStoreHandle) {
*/
def pull(keys: Array[Int], outs: Array[NDArray], priority: Int): Unit = {
require(keys.length == outs.length, "len(keys) != len(outs)")
val outPtrs = outs.map(_.handle.value)
val outPtrs = outs.map(_.handle)
checkCall(_LIB.mxKVStorePull(handle, keys.length, keys, outPtrs, priority))
}

Expand Down
34 changes: 11 additions & 23 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
class LibInfo {
@native def mxNDArrayFree(handle: NDArrayHandle): Int
@native def mxGetLastError(): String
@native def mxNDArrayCreateNone(out: NDArrayHandle): Int
@native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int
@native def mxNDArrayCreate(shape: Array[Int],
ndim: Int,
devType: Int,
devId: Int,
delayAlloc: Int,
out: NDArrayHandle): Int
out: NDArrayHandleRef): Int
@native def mxNDArrayWaitAll(): Int
@native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int
@native def mxFuncDescribe(handle: FunctionHandle,
Expand All @@ -33,13 +33,9 @@ class LibInfo {
argTypes: ListBuffer[String],
argDescs: ListBuffer[String]): Int
@native def mxFuncInvoke(function: FunctionHandle,
// useVars ought to be Array[NDArrayHandle],
// we pass ptr address directly for performance consideration
useVars: Array[CPtrAddress],
useVars: Array[NDArrayHandle],
scalarArgs: Array[MXFloat],
// mutateVars ought to be Array[NDArrayHandle],
// we pass ptr address directly for performance consideration
mutateVars: Array[CPtrAddress]): Int
mutateVars: Array[NDArrayHandle]): Int
@native def mxNDArrayGetShape(handle: NDArrayHandle,
ndim: MXUintRef,
data: ArrayBuffer[Int]): Int
Expand All @@ -49,30 +45,24 @@ class LibInfo {
@native def mxNDArraySlice(handle: NDArrayHandle,
start: MXUint,
end: MXUint,
sliceHandle: NDArrayHandle): Int
sliceHandle: NDArrayHandleRef): Int
@native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
source: Array[MXFloat],
size: Int): Int
@native def mxKVStoreCreate(name: String, handle: KVStoreHandle): Int
@native def mxKVStoreInit(handle: KVStoreHandle,
len: MXUint,
keys: Array[Int],
// values ought to be Array[NDArrayHandle],
// we pass ptr address directly for performance consideration
values: Array[CPtrAddress]): Int
values: Array[NDArrayHandle]): Int
@native def mxKVStorePush(handle: KVStoreHandle,
len: MXUint,
keys: Array[Int],
// values ought to be Array[NDArrayHandle],
// we pass ptr address directly for performance consideration
values: Array[CPtrAddress],
values: Array[NDArrayHandle],
priority: Int): Int
@native def mxKVStorePull(handle: KVStoreHandle,
len: MXUint,
keys: Array[Int],
// outs ought to be Array[NDArrayHandle],
// we pass ptr address directly for performance consideration
outs: Array[CPtrAddress],
outs: Array[NDArrayHandle],
priority: Int): Int
@native def mxKVStoreSetUpdater(handle: KVStoreHandle,
updaterFunc: MXKVStoreUpdater,
Expand Down Expand Up @@ -101,9 +91,9 @@ class LibInfo {
@native def mxDataIterBeforeFirst(handle: DataIterHandle): Int
@native def mxDataIterNext(handle: DataIterHandle, out: RefInt): Int
@native def mxDataIterGetLabel(handle: DataIterHandle,
out: NDArrayHandle): Int
out: NDArrayHandleRef): Int
@native def mxDataIterGetData(handle: DataIterHandle,
out: NDArrayHandle): Int
out: NDArrayHandleRef): Int
@native def mxDataIterGetIndex(handle: DataIterHandle,
outIndex: ListBuffer[Long],
outSize: RefLong): Int
Expand All @@ -115,9 +105,7 @@ class LibInfo {
@native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int
@native def mxExecutorBackward(handle: ExecutorHandle,
gradsSize: Int,
// grads ought to be Array[NDArrayHandle],
// we pass ptr address directly for performance consideration
grads: Array[CPtrAddress]): Int
grads: Array[NDArrayHandle]): Int
@native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int
@native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int

Expand Down
28 changes: 14 additions & 14 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ object NDArray {
require(function != null, s"invalid function name $funcName")
require(output == null || output.writable, "out must be writable")
function match {
case BinaryNDArrayFunction(handle: NDArrayHandle, acceptEmptyMutate: Boolean) =>
case BinaryNDArrayFunction(handle: FunctionHandle, acceptEmptyMutate: Boolean) =>
if (output == null) {
require(acceptEmptyMutate, s"argument out is required to call $funcName")
output = new NDArray(_newEmptyHandle())
}
checkCall(_LIB.mxFuncInvoke(handle,
Array(lhs.handle.value, rhs.handle.value),
Array(lhs.handle, rhs.handle),
Array[MXFloat](),
Array(output.handle.value)))
Array(output.handle)))
case _ => throw new RuntimeException(s"call $funcName as binary function")
}
output
Expand All @@ -53,9 +53,9 @@ object NDArray {
output = new NDArray(_newEmptyHandle())
}
checkCall(_LIB.mxFuncInvoke(handle,
Array(src.handle.value),
Array(src.handle),
Array[MXFloat](),
Array(output.handle.value)))
Array(output.handle)))
case _ => throw new RuntimeException(s"call $funcName as unary function")
}
output
Expand Down Expand Up @@ -88,9 +88,9 @@ object NDArray {
mutateVars = Array.fill[NDArray](nMutateVars)(new NDArray(_newEmptyHandle()))
}
checkCall(_LIB.mxFuncInvoke(handle,
useVarsRange.map(args(_).asInstanceOf[NDArray].handle.value).toArray,
useVarsRange.map(args(_).asInstanceOf[NDArray].handle).toArray,
scalarRange.map(args(_).asInstanceOf[MXFloat]).toArray,
mutateVars.map(_.handle.value).array))
mutateVars.map(_.handle).array))
case _ => throw new RuntimeException(s"call $funcName as generic function")
}
mutateVars
Expand All @@ -103,9 +103,9 @@ object NDArray {
* @return a new empty ndarray handle
*/
private def _newEmptyHandle(): NDArrayHandle = {
val hdl: NDArrayHandle = new NDArrayHandle
val hdl = new NDArrayHandleRef
checkCall(_LIB.mxNDArrayCreateNone(hdl))
hdl
hdl.value
}

/**
Expand All @@ -117,15 +117,15 @@ object NDArray {
private def _newAllocHandle(shape: Array[Int],
ctx: Context,
delayAlloc: Boolean): NDArrayHandle = {
val hdl = new NDArrayHandle
val hdl = new NDArrayHandleRef
checkCall(_LIB.mxNDArrayCreate(
shape,
shape.length,
ctx.deviceTypeid,
ctx.deviceId,
if (delayAlloc) 1 else 0,
hdl))
hdl
hdl.value
}

/**
Expand Down Expand Up @@ -332,7 +332,7 @@ object NDArray {
* NDArray is basic ndarray/Tensor like data structure in mxnet.
*/
// scalastyle:off finalize
class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = true) {
override def finalize(): Unit = {
checkCall(_LIB.mxNDArrayFree(handle))
}
Expand All @@ -356,9 +356,9 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
* @return a sliced NDArray that shares memory with current one.
*/
def slice(start: Int, stop: Int): NDArray = {
val sliceHandle = new NDArrayHandle()
val sliceHandle = new NDArrayHandleRef
checkCall(_LIB.mxNDArraySlice(handle, start, stop, sliceHandle))
new NDArray(handle = sliceHandle, writable = this.writable)
new NDArray(handle = sliceHandle.value, writable = this.writable)
}

def slice(start: Int): NDArray = {
Expand Down
Loading

0 comments on commit 722c7a8

Please sign in to comment.