Skip to content

Commit

Permalink
Merge pull request apache#18 from javelinjs/scala-package-cc
Browse files Browse the repository at this point in the history
Symbol creation
  • Loading branch information
yanqingmen committed Jan 5, 2016
2 parents 887a6b7 + e1aa386 commit 1ab9104
Show file tree
Hide file tree
Showing 10 changed files with 688 additions and 12 deletions.
39 changes: 39 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package ml.dmlc.mxnet

/**
* Attribute manager for scoping.
* User can also inherit this object to change naming behavior.
* @author Yizhi Liu
*/
class AttrScope(attr: Map[String, String] = Map.empty) {
private var _attr = attr
/**
* Get the attribute dict given the attribute set by the symbol.
* @param userDefinedAttr The attribute passed in by user during symbol creation.
* @return Updated attributes to add other scope related attributes.
*/
def get(userDefinedAttr: Option[Map[String, String]]): Map[String, String] = {
_attr ++ userDefinedAttr.getOrElse(Map.empty[String, String])
}

def withScope[T](body: => T): T = {
val oldAttrScope = AttrScope.current
this._attr = AttrScope.current._attr ++ this._attr
AttrScope.setCurrentAttr(this)
try {
body
} finally {
AttrScope.setCurrentAttr(oldAttrScope)
}
}
}

object AttrScope {
private var _current = new AttrScope()
def current: AttrScope = _current
private def setCurrentAttr(attr: AttrScope): Unit = {
_current = attr
}

def apply(attr: Map[String, String] = Map.empty): AttrScope = new AttrScope(attr)
}
5 changes: 3 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ object Base {
type MXFloat = Float
type CPtrAddress = Long

type SymbolHandle = CPtrAddress

type MXUintRef = RefInt
type MXFloatRef = RefFloat
type NDArrayHandle = RefLong
Expand All @@ -19,12 +21,11 @@ object Base {
type DataIterCreator = RefLong
type KVStoreHandle = RefLong
type ExecutorHandle = RefLong

type SymbolHandleRef = RefLong

System.loadLibrary("mxnet-scala")
val _LIB = new LibInfo


// helper function definitions
/**
* Check the return value of C API call
Expand Down
2 changes: 1 addition & 1 deletion scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ object IO {
/**
* create iterator via iterName and params
* @param iterName name of iterator; "MNISTIter" or "ImageRecordIter"
* @param params paramters for create iterator
* @param params parameters for create iterator
* @return
*/
def createIterator(iterName: String, params: Map[String, String]): DataIter = {
Expand Down
32 changes: 32 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,36 @@ class LibInfo {
grads: Array[CPtrAddress]): Int
@native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int
@native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int

// Symbols
@native def mxSymbolListAtomicSymbolCreators(symbolList: ListBuffer[SymbolHandle]): Int
@native def mxSymbolGetAtomicSymbolInfo(handle: SymbolHandle,
name: RefString,
desc: RefString,
numArgs: MXUintRef,
argNames: ListBuffer[String],
argTypes: ListBuffer[String],
argDescs: ListBuffer[String],
keyVarNumArgs: RefString): Int
@native def mxSymbolCreateAtomicSymbol(handle: SymbolHandle,
paramKeys: Array[String],
paramVals: Array[String],
symHandleRef: SymbolHandleRef): Int
@native def mxSymbolSetAttr(handle: SymbolHandle, key: String, value: String): Int
@native def mxSymbolCompose(handle: SymbolHandle,
name: String,
keys: Array[String],
args: Array[SymbolHandle]): Int
@native def mxSymbolCreateVariable(name: String, out: SymbolHandleRef): Int
@native def mxSymbolGetAttr(handle: SymbolHandle,
key: String,
ret: RefString,
success: RefInt): Int
@native def mxSymbolListArguments(handle: SymbolHandle,
arguments: ArrayBuffer[String]): Int
@native def mxSymbolCopy(handle: SymbolHandle, clonedHandle: SymbolHandleRef): Int
@native def mxSymbolListOutputs(handle: SymbolHandle,
outputs: ArrayBuffer[String]): Int
@native def mxSymbolCreateGroup(handles: Array[SymbolHandle], out: SymbolHandleRef): Int
@native def mxSymbolPrint(handle: SymbolHandle, str: RefString): Int
}
51 changes: 51 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package ml.dmlc.mxnet

import scala.collection.mutable

/**
* NameManager to do automatic naming.
* User can also inherit this object to change naming behavior.
* @author Yizhi Liu
*/
class NameManager {
val counter: mutable.Map[String, Int] = mutable.HashMap.empty[String, Int]
/**
* Get the canonical name for a symbol.
* This is default implementation.
* When user specified a name,
* the user specified name will be used.
* When user did not, we will automatically generate a name based on hint string.
*
* @param name : The name user specified.
* @param hint : A hint string, which can be used to generate name.
* @return A canonical name for the user.
*/
def get(name: Option[String], hint: String): String = {
name.getOrElse {
if (!counter.contains(hint)) {
counter(hint) = 0
}
val generatedName = s"$hint${counter(hint)}"
counter(hint) += 1
generatedName
}
}

def withScope[T](body: => T): T = {
val oldManager = NameManager.current
NameManager.setCurrentManager(this)
try {
body
} finally {
NameManager.setCurrentManager(oldManager)
}
}
}

object NameManager {
private var _current = new NameManager()
def current: NameManager = _current
private def setCurrentManager(manager: NameManager): Unit = {
_current = manager
}
}
Loading

0 comments on commit 1ab9104

Please sign in to comment.