Commit
Merge pull request apache#29 from javelinjs/scala-package-l
LGTM
yanqingmen committed Feb 9, 2016
2 parents 77e79c9 + 9cf857b commit 67dc44c
Showing 17 changed files with 406 additions and 77 deletions.
4 changes: 4 additions & 0 deletions scala-package/core/pom.xml
@@ -85,5 +85,9 @@
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
</dependency>
</dependencies>
</project>
2 changes: 1 addition & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
@@ -11,7 +11,7 @@ object Base {
type MXFloat = Float
type CPtrAddress = Long
// TODO: make it more friendly to java
type Shape = Array[Int]
type Shape = Vector[Int]

type NDArrayHandle = CPtrAddress
type FunctionHandle = CPtrAddress
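
The Shape alias switches from Array[Int] to Vector[Int]. A minimal sketch (not part of this commit) of the practical difference: Vector compares element-wise, while Array compares by reference, so shapes held as Vector can be compared with == directly.

// Illustration only: shape equality with Vector vs. Array.
val v1: Vector[Int] = Vector(2, 3)
val v2: Vector[Int] = Vector(2, 3)
assert(v1 == v2)                // Vector: structural (element-wise) equality

val a1: Array[Int] = Array(2, 3)
val a2: Array[Int] = Array(2, 3)
assert(a1 != a2)                // Array: reference equality, so distinct instances differ
assert(a1 sameElements a2)      // element-wise comparison needs sameElements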
14 changes: 13 additions & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala
@@ -3,7 +3,9 @@ package ml.dmlc.mxnet
object Context {
val devtype2str = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned")
val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3)
val defaultCtx = new Context("cpu", 0)
private var _defaultCtx = new Context("cpu", 0)

def defaultCtx: Context = _defaultCtx

def cpu(deviceId: Int = 0): Context = {
new Context("cpu", deviceId)
@@ -13,6 +15,16 @@ object Context {
new Context("gpu", deviceId)
}

def withScope[T](device: Context)(body: => T): T = {
val oldDefaultCtx = Context.defaultCtx
Context._defaultCtx = device
try {
body
} finally {
Context._defaultCtx = oldDefaultCtx
}
}

implicit def ctx2Array(ctx: Context): Array[Context] = Array(ctx)
}
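
A hedged usage sketch of the new Context.withScope helper (illustrative, not part of the diff): the body runs with the given device as the default context, and the previous default is restored afterwards, even if the body throws.

// Illustration only: temporarily make GPU 0 the default context.
val onGpu = Context.withScope(Context.gpu(0)) {
  // Inside the block Context.defaultCtx is gpu(0), so allocations that
  // fall back to the default context land on the GPU.
  NDArray.zeros(2, 2)
}
// After the block, the previous default context is back in effect.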

@@ -294,7 +294,7 @@ class DataParallelExecutorManager(symbol: Symbol,
ctx.zipWithIndex.map { case (context, i) =>
val dataShapes =
trainData.provideData.map { case (name: String, shape: Shape) =>
(name, Array(slices(i)._2 - slices(i)._1) ++ shape.drop(1))
(name, Vector(slices(i)._2 - slices(i)._1) ++ shape.drop(1))
}
symbol.simpleBind(context, "write", shapeDict = dataShapes)
}
@@ -334,7 +334,7 @@ class DataParallelExecutorManager(symbol: Symbol,
}.toArray
private val batchSize = trainData.batchSize
private val outputShapes: Array[Shape] = trainExecs(0).outputs.map { x: NDArray =>
Array(batchSize) ++ x.shape.drop(1)
Vector(batchSize) ++ x.shape.drop(1)
}
private[mxnet] val cpuOutputArrays = outputShapes.map(NDArray.zeros(_))

9 changes: 6 additions & 3 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
@@ -164,9 +164,9 @@ class LibInfo {
keys: Array[String],
argIndPtr: Array[MXUint],
argShapeData: Array[MXUint],
inShapeData: ListBuffer[Shape],
outShapeData: ListBuffer[Shape],
auxShapeData: ListBuffer[Shape],
inShapeData: ListBuffer[Array[Int]],
outShapeData: ListBuffer[Array[Int]],
auxShapeData: ListBuffer[Array[Int]],
complete: RefInt): Int
@native def mxSymbolGetOutput(handle: SymbolHandle, index: Int, out: SymbolHandleRef): Int
@native def mxSymbolSaveToJSON(handle: SymbolHandle, out: RefString): Int
@@ -185,4 +185,7 @@ class LibInfo {
auxArgsHandle: Array[NDArrayHandle],
out: ExecutorHandleRef): Int
// scalastyle:on parameterNum

// Random
@native def mxRandomSeed(seed: Int): Int
}
@@ -426,6 +426,7 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp
}
batch = data.next()
}
// TODO: we can use Symbol.concat to do the same thing. Can it be more efficient?
outputs.map(NDArray.concatenate(_))
}

63 changes: 39 additions & 24 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -125,12 +125,12 @@ object NDArray {
*
* @return a new empty ndarray handle
*/
private def newAllocHandle(shape: Array[Int],
private def newAllocHandle(shape: Shape,
ctx: Context,
delayAlloc: Boolean): NDArrayHandle = {
val hdl = new NDArrayHandleRef
checkCall(_LIB.mxNDArrayCreate(
shape,
shape.toArray,
shape.length,
ctx.deviceTypeid,
ctx.deviceId,
@@ -217,14 +217,14 @@
*
* @return The created NDArray.
*/
def empty(shape: Array[Int], ctx: Context = null): NDArray = {
def empty(shape: Shape, ctx: Context = null): NDArray = {
val context = if (ctx == null) Context.defaultCtx else ctx
new NDArray(handle = NDArray.newAllocHandle(shape, context, delayAlloc = false))
}

def empty(shape: Int *): NDArray = empty(shape.toArray)
def empty(shape: Int *): NDArray = empty(shape.toVector)

def empty(ctx: Context, shape: Int *): NDArray = empty(shape.toArray, ctx)
def empty(ctx: Context, shape: Int *): NDArray = empty(shape.toVector, ctx)

/**
* Create a new NDArray filled with 0, with specified shape.
@@ -234,31 +234,31 @@
*
* @return The created NDArray.
*/
def zeros(shape: Array[Int], ctx: Context = null): NDArray = {
def zeros(shape: Shape, ctx: Context = null): NDArray = {
val arr = empty(shape, ctx)
arr.set(0f)
arr
}

def zeros(shape: Int *): NDArray = zeros(shape.toArray)
def zeros(shape: Int *): NDArray = zeros(shape.toVector)

def zeros(ctx: Context, shape: Int *): NDArray = zeros(shape.toArray, ctx)
def zeros(ctx: Context, shape: Int *): NDArray = zeros(shape.toVector, ctx)

/**
* Create a new NDArray filled with 1, with specified shape.
* @param shape shape of the NDArray.
* @param ctx The context of the NDArray, default to current default context.
* @return The created NDArray.
*/
def ones(shape: Array[Int], ctx: Context = null): NDArray = {
def ones(shape: Shape, ctx: Context = null): NDArray = {
val arr = empty(shape, ctx)
arr.set(1f)
arr
}

def ones(shape: Int *): NDArray = ones(shape.toArray)
def ones(shape: Int *): NDArray = ones(shape.toVector)

def ones(ctx: Context, shape: Int *): NDArray = ones(shape.toArray, ctx)
def ones(ctx: Context, shape: Int *): NDArray = ones(shape.toVector, ctx)

/**
* Clip ndarray elements to range (from, to)
@@ -460,32 +460,29 @@
* @param ctx The context of the NDArray, default to current default context.
* @return The created NDArray.
*/
def array(sourceArr: Array[Float], shape: Array[Int], ctx: Context = null): NDArray = {
def array(sourceArr: Array[Float], shape: Shape, ctx: Context = null): NDArray = {
val arr = empty(shape, ctx)
arr.set(sourceArr)
arr
}

/**
* Join a sequence of arrays at the first axis
* Join a sequence of arrays at axis-0
* TODO: shall we make it native?
* @param arrays
*/
def concatenate(arrays: Seq[NDArray], ctx: Context = null): NDArray = {
require(arrays != null && arrays.size > 0, "arrays empty")
val array0 = arrays.head
val shape = Array.ofDim[Int](array0.shape.length)
array0.shape.copyToArray(shape)
var axis0 = shape(0)
val shapeRemain = shape.drop(1)
val shape = array0.shape.drop(1)
var axis0 = array0.shape(0)
arrays.drop(1).foreach { array =>
require(shapeRemain.sameElements(array.shape.drop(1)),
require(shape.sameElements(array.shape.drop(1)),
s"shape mismatch between (${array.shape.mkString(",")}) and (${shape.mkString(",")})")
axis0 += array.shape(0)
}

shape(0) = axis0
val output = NDArray.empty(shape, ctx)
val output = NDArray.empty(Vector(axis0) ++ shape, ctx)
axis0 = 0
arrays.foreach { array =>
output.slice(axis0, axis0 + array.shape(0)).set(array)
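
A hedged usage sketch of NDArray.concatenate (illustrative, not part of the diff): all inputs must agree on every axis except axis 0, and the result stacks them along axis 0.

// Illustration only: concatenate two arrays along axis 0.
val top = NDArray.ones(2, 3)         // shape Vector(2, 3)
val bottom = NDArray.zeros(4, 3)     // shape Vector(4, 3)
val joined = NDArray.concatenate(Seq(top, bottom))
// joined.shape == Vector(6, 3); rows 0-1 are ones, rows 2-5 are zeros.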
@@ -604,8 +601,14 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean =
slice(range._1, range._2)
}

def slice(start: Int): NDArray = {
slice(start, shape(0))
/**
* Return a sliced NDArray at the i-th position on axis 0.
* NDArray only supports contiguous slicing on axis 0.
* @param i
* @return a sliced NDArray that shares memory with current one.
*/
def slice(i: Int): NDArray = {
slice(i, i + 1)
}
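
A hedged sketch of the changed slice(i) behaviour (illustrative, not part of the diff): it now returns the single position i on axis 0 rather than everything from i to the end, and the returned view shares memory with the source.

// Illustration only: slice out row 1 of a 3x2 array.
val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), Vector(3, 2))
val row = arr.slice(1)    // shape Vector(1, 2), holding (3f, 4f)
row.set(0f)               // the slice shares memory, so row 1 of arr becomes zeros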

/**
@@ -804,16 +807,28 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean =
* Get shape of current NDArray.
* @return an array representing shape of current ndarray
*/
def shape: Array[Int] = {
def shape: Shape = {
val ndim = new MXUintRef
val data = ArrayBuffer[Int]()
checkCall(_LIB.mxNDArrayGetShape(handle, ndim, data))
require(ndim.value == data.length, s"ndim=$ndim, while len(pdata)=${data.length}")
data.toArray
data.toVector
}

// Get size of current NDArray.
def size: Int = shape.product

override def equals(o: Any): Boolean = o match {
case that: NDArray => {
that.shape == this.shape && that.toArray.sameElements(this.toArray)
}
case _ => false
}

override def hashCode: Int = {
// TODO: naive implementation
shape.hashCode + toArray.hashCode
}
}
// scalastyle:on finalize
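
A hedged sketch of the new structural equality on NDArray (illustrative, not part of the diff): two arrays compare equal when both their shapes and their contents match. hashCode is still flagged as a naive TODO above, so only equals is exercised here.

// Illustration only: NDArray equality is now by shape and contents.
val p = NDArray.array(Array(1f, 2f, 3f, 4f), Vector(2, 2))
val q = NDArray.array(Array(1f, 2f, 3f, 4f), Vector(2, 2))
val r = NDArray.zeros(2, 2)
assert(p == q)    // same shape and same elements
assert(p != r)    // same shape but different contents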

5 changes: 2 additions & 3 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
@@ -1,6 +1,6 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base.Shape
import ml.dmlc.mxnet.Base._
import ml.dmlc.mxnet.NDArray.{randomGaussian, randomUniform, empty}

/**
Expand Down Expand Up @@ -73,7 +73,6 @@ object Random {
* generated from GPU0 can be different from CPU.
*/
def seed(seedState: Int): Unit = {
// TODO
// checkCall(_LIB.mxRandomSeed(seedState))
checkCall(_LIB.mxRandomSeed(seedState))
}
}
60 changes: 51 additions & 9 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
@@ -202,15 +202,15 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {

def inferShape(keys: Array[String], indPtr: Array[Int], values: Array[Int])
: (Seq[Shape], Seq[Shape], Seq[Shape]) = {
val argShapeData = ListBuffer.empty[Shape]
val outShapeData = ListBuffer.empty[Shape]
val auxShapeData = ListBuffer.empty[Shape]
val argShapeData = ListBuffer.empty[Array[Int]]
val outShapeData = ListBuffer.empty[Array[Int]]
val auxShapeData = ListBuffer.empty[Array[Int]]
val complete = new RefInt

checkCall(_LIB.mxSymbolInferShape(handle, indPtr.size - 1, keys, indPtr, values,
argShapeData, outShapeData, auxShapeData, complete))
if (complete.value != 0) {
(argShapeData, outShapeData, auxShapeData)
(argShapeData.map(_.toVector), outShapeData.map(_.toVector), auxShapeData.map(_.toVector))
} else {
(null, null, null)
}
@@ -555,6 +555,10 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
bind(ctx, args, argsGrad, "write", Nil, null)
}

def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray]): Executor = {
bind(ctx, args, argsGrad, "write", Nil, null)
}

def bind(ctx: Context, args: Seq[NDArray]): Executor = {
val symbolArguments = listArguments()
bindHelper(ctx, symbolArguments, args, null,
@@ -670,6 +674,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
}

object Symbol {
private type SymbolCreateFunc = Map[String, Any] => Symbol
private val logger = LoggerFactory.getLogger(classOf[Symbol])
private val functions: Map[String, SymbolFunction] = initSymbolModule()
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)
@@ -688,23 +693,23 @@ object Symbol {
sym
}

def FullyConnected: Map[String, Any] => Symbol = {
def FullyConnected: SymbolCreateFunc = {
FullyConnected(null)
}

def FullyConnected(attr: Map[String, String]): Map[String, Any] => Symbol = {
def FullyConnected(attr: Map[String, String]): SymbolCreateFunc = {
createNoCheck("FullyConnected", attr)
}

def Activation: Map[String, Any] => Symbol = {
def Activation: SymbolCreateFunc = {
Activation(null)
}

def Activation(attr: Map[String, String]): Map[String, Any] => Symbol = {
def Activation(attr: Map[String, String]): SymbolCreateFunc = {
createNoCheck("Activation", attr)
}

def Convolution(attr: Map[String, String]): Map[String, Any] => Symbol = {
def Convolution(attr: Map[String, String]): SymbolCreateFunc = {
createNoCheck("Convolution", attr)
}

@@ -732,6 +737,43 @@
createNoCheck("Cast")
}

def ElementWiseSum(name: String, inputs: Symbol *): Symbol = {
create("ElementWiseSum", inputs.toArray, Map("name" -> name), null)
}

def ElementWiseSum(inputs: Seq[Symbol], name: String): Symbol = {
create("ElementWiseSum", inputs.toArray, Map("name" -> name), null)
}

def Concat(inputs: Seq[Symbol],
paramKwargs: Map[String, Any],
attr: Map[String, String] = null): Symbol = {
create("Concat", inputs.toArray,
paramKwargs.map { case (k, v) => (k, v.toString) }, attr)
}

// Use logistic regression for the final output of a net.
// Logistic regression is suitable for binary classification or probability prediction tasks.
def LogisticRegressionOutput(inputs: Seq[Symbol], attr: Map[String, String] = null): Symbol = {
create("LogisticRegressionOutput", inputs.toArray, null, attr)
}

// Use linear regression for the final output of a net.
def LinearRegressionOutput(inputs: Seq[Symbol], attr: Map[String, String] = null): Symbol = {
create("LinearRegressionOutput", inputs.toArray, null, attr)
}

/**
* Apply swapaxis to input.
* @param data Input data to the SwapAxisOp.
* @param dim1 (non-negative), default=0, the first axis to be swapped.
* @param dim2 (non-negative), default=0, the second axis to be swapped.
*/
def SwapAxis(data: Symbol, dim1: Int = 0, dim2: Int = 0,
attr: Map[String, String] = null): Symbol = {
createNoCheck("SwapAxis")(Map("data" -> data, "dim1" -> dim1, "dim2" -> dim2))
}
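
A hedged sketch of how the new Symbol helpers might be combined (illustrative, not part of the diff; Symbol.Variable is assumed to exist elsewhere in the package).

// Illustration only: element-wise sum and axis swap on symbolic variables.
val a = Symbol.Variable("a")
val b = Symbol.Variable("b")
val summed  = Symbol.ElementWiseSum("sum", a, b)             // variadic overload
val swapped = Symbol.SwapAxis(data = a, dim1 = 0, dim2 = 1)  // swap the first two axes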

/**
* Create a symbol that groups symbols together.
* @param symbols List of symbols to be grouped.
16 changes: 16 additions & 0 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala
@@ -0,0 +1,16 @@
package ml.dmlc.mxnet

object CheckUtils {
def reldiff(a: NDArray, b: NDArray): Float = {
val diff = NDArray.sum(NDArray.abs(a - b)).toScalar
val norm = NDArray.sum(NDArray.abs(a)).toScalar
diff / norm
}

def reldiff(a: Array[Float], b: Array[Float]): Float = {
val diff =
(a zip b).map { case (aElem, bElem) => Math.abs(aElem - bElem) }.sum
val norm: Float = a.reduce(Math.abs(_) + Math.abs(_))
diff / norm
}
}
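
A hedged usage sketch of the new reldiff helper (illustrative, not part of the diff): it returns the summed absolute difference normalised by the magnitude of the first argument, which is handy for approximate comparisons in tests.

// Illustration only: compare two float arrays within a relative tolerance.
val expected = Array(1.0f, 2.0f, 3.0f)
val actual   = Array(1.0f, 2.1f, 2.9f)
val rel = CheckUtils.reldiff(expected, actual)   // (0.1f + 0.1f) / 6f ≈ 0.033
assert(rel < 1e-1f)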