
[Numpy] Java/Scala modification (#14625)
* modify jni to support 0 dim/shape

* fix transpose axes default value
yzhliu authored and reminisce committed Apr 6, 2019
1 parent 58beb06 commit dc7cf99
Showing 10 changed files with 452 additions and 142 deletions.
142 changes: 48 additions & 94 deletions scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -61,10 +61,6 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
   protected var monitorCallback: MXMonitorCallback = null
   private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])
 
-  private[mxnet] var ownsArgArrays = false
-  private[mxnet] var ownsGradArrays = false
-  private[mxnet] var ownsAuxArrays = false
-
   override def nativeAddress: CPtrAddress = handle
   override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
   // cannot determine the off-heap size of this object
@@ -75,17 +71,12 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
     if (!super.isDisposed) {
       super.dispose()
       outputs.foreach(o => o.dispose())
-      // Symbol.bind clones symbol when creating the executor so we need to dispose of the clone
-      symbol.dispose()
-      if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => a.dispose())}
-      if (ownsGradArrays && gradArrays != null) {gradArrays.foreach(
+      if (argArrays != null) {argArrays.foreach(a => a.dispose())}
+      if (gradArrays != null) {gradArrays.foreach(
         // Symbol will sometimes fill this with nulls so we've got to check the elements too
         a => if (a != null) {a.dispose()})
       }
-      if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => a.dispose())}
-      if (_argDict != null) {_argDict.foreach(a => a._2.dispose())}
-      if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())}
-      if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())}
+      if (auxArrays != null) {auxArrays.foreach(a => a.dispose())}
     }
   }
 
@@ -104,95 +95,58 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
    */
   def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
               kwargs: Map[String, Shape]): Executor = {
-    var setArgOwner = false
-    var setAuxOwner = false
-    var setGradOwner = false
-    val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
-    // TODO: more precise error message should be provided by backend
-    require(argShapes != null, "Shape inference failed." +
-      s"Known shapes are $kwargs for symbol arguments ${symbol.listArguments()} " +
-      s"and aux states ${symbol.listAuxiliaryStates()}")
-
-    var newArgDict = Map[String, NDArray]()
-    var newGradDict = Map[String, NDArray]()
-
-    this.symbol.listArguments().zipWithIndex.foreach { case (name, i) =>
-      val newShape = argShapes(i)
-      val arr = this.argArrays(i)
-      val dArr = if (this.gradArrays == null) null else this.gradArrays(i)
-      if (partialShaping || kwargs.contains(name) || newShape.equals(arr.shape)) {
-        if (newShape.product > arr.shape.product) {
-          require(allowUpSizing, s"New shape of arg:$name larger than original. " +
-            "First making a big executor and then down sizing it " +
-            "is more efficient than the reverse." +
-            "If you really want to up size, set allowUpSizing = true " +
-            "to enable allocation of new arrays.")
-          newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype))
-          setArgOwner = true
-          if (dArr != null) {
-            newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype))
-            setGradOwner = true
-          }
-        } else {
-          newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
-          if (dArr != null) {
-            newGradDict = newGradDict + (name -> dArr.reshape(newShape.toArray))
-          }
-        }
-      } else {
-        throw new AssertionError(s"Shape of unspecified array arg:$name changed." +
-          "This can cause the new executor to not share parameters " +
-          "with the old one. Please check for error in network." +
-          "If this is intended, set partialShaping = true to suppress this warning.")
-      }
-    }
+    val providedArgShapeNames = kwargs.keys
+    val providedArgShapeData = kwargs.values.flatMap(_.toVector)
+    val providedArgShapeIdx = kwargs.values.scanLeft(0)((sum, shape) => sum + shape.size)
 
-    var newAuxDict = Map[String, NDArray]()
-    val zip3 = (this.symbol.listAuxiliaryStates(), auxShapes, this.auxArrays).zipped
-    zip3.foreach { case (name, newShape, arr) =>
-      if (partialShaping || newShape.equals(arr.shape)) {
-        if (newShape.product > arr.shape.product) {
-          require(allowUpSizing, s"New shape of aux:$name larger than original. " +
-            "First making a big executor and then down sizing it " +
-            "is more efficient than the reverse." +
-            "If you really want to up size, set allowUpSizing = true " +
-            "to enable allocation of new arrays.")
-          newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context))
-          setAuxOwner = true
-        } else {
-          newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray))
-        }
-      } else {
-        throw new AssertionError(s"Shape of unspecified array aux:$name changed." +
-          "This can cause the new executor to not share parameters " +
-          "with the old one. Please check for error in network." +
-          "If this is intended, set partialShaping = true to suppress this warning.")
-      }
-    }
-    val reshapedExecutor = if (this._gradsReq.isInstanceOf[Seq[_]]) {
-      this.symbol.bind(this._ctx,
-        newArgDict,
-        newGradDict,
-        this._gradsReq.asInstanceOf[Seq[String]],
-        newAuxDict,
-        this._group2ctx,
-        this)
+    val ctxMapKeys = if (_group2ctx != null) _group2ctx.keys.toArray else Array.empty[String]
+    val ctxMapDevTypes = if (_group2ctx != null) {
+      _group2ctx.values.map(_.deviceTypeid).toArray
+    } else {
+      Array.empty[Int]
+    }
+    val ctxMapDevIds = if (_group2ctx != null) {
+      _group2ctx.values.map(_.deviceId).toArray
     } else {
-      this.symbol.bind(this._ctx,
-        newArgDict,
-        newGradDict,
-        this._gradsReq.asInstanceOf[Map[String, String]],
-        newAuxDict,
-        this._group2ctx,
-        this)
+      Array.empty[Int]
     }
 
-    // This method has created new NDArrays that will need to be managed by the new Executor
-    if (setArgOwner) reshapedExecutor.ownsArgArrays = true
-    if (setGradOwner) reshapedExecutor.ownsGradArrays = true
-    if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true
+    val inArgs = ArrayBuffer.empty[NDArrayHandle]
+    val argGrads = ArrayBuffer.empty[NDArrayHandle]
+    val auxStates = ArrayBuffer.empty[NDArrayHandle]
+    val outHandle = new ExecutorHandleRef()
 
-    reshapedExecutor
+    checkCall(_LIB.mxExecutorReshape(
+      if (partialShaping) 1 else 0,
+      if (allowUpSizing) 1 else 0,
+      _ctx.deviceTypeid,
+      _ctx.deviceId,
+      ctxMapKeys.toArray,
+      ctxMapDevTypes.toArray,
+      ctxMapDevIds.toArray,
+      providedArgShapeNames.toArray,
+      providedArgShapeData.toArray,
+      providedArgShapeIdx.toArray,
+      inArgs,
+      argGrads,
+      auxStates,
+      this.handle,
+      outHandle))
+
+    val argArrays = inArgs.map(new NDArray(_)).toArray
+    val gradArrays = argGrads.map(handle =>
+      if (handle == 0) null else new NDArray(handle)).toArray
+    val auxArrays = auxStates.map(new NDArray(_)).toArray
+
+    val executor = new Executor(outHandle.value, this.symbol)
+    executor._ctx = this._ctx
+    executor._gradsReq = this._gradsReq
+    executor._group2ctx = this._group2ctx
+    executor.argArrays = argArrays
+    executor.gradArrays = gradArrays
+    executor.auxArrays = auxArrays
+    executor
   }
 
   /**
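For orientation: a hedged sketch of driving the rewritten reshape from user code. The network, the simpleBind call, and the shapes below are illustrative assumptions, not part of this commit.

    import org.apache.mxnet._

    object ReshapeSketch {
      def main(args: Array[String]): Unit = {
        // Bind a tiny net for batch size 32 (shapes are made up for the example).
        val data = Symbol.Variable("data")
        val net = Symbol.api.FullyConnected(data = Some(data), num_hidden = 10)
        val exec = net.simpleBind(Context.cpu(), "write", Map("data" -> Shape(32, 100)))
        // reshape now delegates to the native mxExecutorReshape call and hands back
        // a brand-new Executor that owns freshly allocated NDArrays.
        val bigger = exec.reshape(allowUpSizing = true,
          kwargs = Map("data" -> Shape(64, 100)))
        bigger.dispose()
        exec.dispose()
      }
    }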
32 changes: 31 additions & 1 deletion scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -188,6 +188,23 @@ private[mxnet] class LibInfo {
                                 grads: Array[NDArrayHandle]): Int
   @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int
   @native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int
+  // scalastyle:off parameterNum
+  @native def mxExecutorReshape(partialShaping: Int,
+                                allowUpSizing: Int,
+                                devType: Int,
+                                devId: Int,
+                                mapKeys: Array[String],
+                                mapDevTypes: Array[Int],
+                                mapDevIds: Array[Int],
+                                providedArgShapeNames: Array[String],
+                                providedArgShapeData: Array[Int],
+                                providedArgShapeIdx: Array[Int],
+                                inArgs: ArrayBuffer[NDArrayHandle],
+                                argGrads: ArrayBuffer[NDArrayHandle],
+                                auxStates: ArrayBuffer[NDArrayHandle],
+                                sharedExec: ExecutorHandle,
+                                out: ExecutorHandleRef): Int
+  // scalastyle:on parameterNum
 
   // Symbols
   @native def mxSymbolListAtomicSymbolCreators(symbolList: ListBuffer[SymbolHandle]): Int
@@ -240,11 +257,20 @@ private[mxnet] class LibInfo {
                                  numArgs: MXUint,
                                  keys: Array[String],
                                  argIndPtr: Array[MXUint],
-                                 argShapeData: Array[MXUint],
+                                 argShapeData: Array[Int],
                                  inShapeData: ListBuffer[Array[Int]],
                                  outShapeData: ListBuffer[Array[Int]],
                                  auxShapeData: ListBuffer[Array[Int]],
                                  complete: RefInt): Int
+  @native def mxSymbolInferShapePartial(handle: SymbolHandle,
+                                        numArgs: MXUint,
+                                        keys: Array[String],
+                                        argIndPtr: Array[MXUint],
+                                        argShapeData: Array[Int],
+                                        inShapeData: ListBuffer[Array[Int]],
+                                        outShapeData: ListBuffer[Array[Int]],
+                                        auxShapeData: ListBuffer[Array[Int]],
+                                        complete: RefInt): Int
   @native def mxSymbolGetOutput(handle: SymbolHandle, index: Int, out: SymbolHandleRef): Int
   @native def mxSymbolSaveToJSON(handle: SymbolHandle, out: RefString): Int
   @native def mxSymbolCreateFromJSON(json: String, handle: SymbolHandleRef): Int
@@ -322,4 +348,8 @@ private[mxnet] class LibInfo {
   @native def mxSetProfilerConfig(keys: Array[String], vals: Array[String]): Int
   @native def mxSetProfilerState(state: Int): Int
   @native def mxDumpProfile(finished: Int): Int
+
+  // Numpy
+  @native def mxIsNumpyCompatible(compatible: RefInt): Int
+  @native def mxSetIsNumpyCompatible(isNpComp: Int, prev: RefInt): Int
 }
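Worth noting for the new mxExecutorReshape signature: the provided shapes cross the JNI boundary in a CSR-like layout, one flat array of dimensions (providedArgShapeData) plus an offset array delimiting each shape (providedArgShapeIdx). A minimal sketch of the flattening Executor.reshape performs, with assumed example values:

    import org.apache.mxnet.Shape

    val kwargs = Map("data" -> Shape(64, 100), "label" -> Shape(64))
    val shapeData = kwargs.values.flatMap(_.toVector)                   // 64, 100, 64
    val shapeIdx = kwargs.values.scanLeft(0)((sum, s) => sum + s.size)  // 0, 2, 3
    // shape i occupies shapeData slots [shapeIdx(i), shapeIdx(i + 1))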
10 changes: 7 additions & 3 deletions scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -1274,11 +1274,15 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
    * @return an array representing shape of current ndarray
    */
   def shape: Shape = {
-    val ndim = new MXUintRef
+    val ndim = new RefInt
     val data = ArrayBuffer[Int]()
     checkCall(_LIB.mxNDArrayGetShape(handle, ndim, data))
-    require(ndim.value == data.length, s"ndim=$ndim, while len(data)=${data.length}")
-    Shape(data)
+    if (ndim.value == -1) {
+      null
+    } else {
+      require(ndim.value == data.length, s"ndim=$ndim, while len(data)=${data.length}")
+      Shape(data)
+    }
   }
 
   // Get size of current NDArray.
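One caller-facing consequence: with numpy compatibility enabled the backend may report ndim == -1 (shape not yet determined), and shape then returns null instead of tripping the require. A hedged sketch of defensive caller code (the array here has a known shape; the None branch is what a shape-undetermined array would hit):

    import org.apache.mxnet._

    val arr = NDArray.ones(Shape(2, 3))
    Option(arr.shape) match {
      case Some(s) => println(s"shape known: $s")                // taken here
      case None    => println("shape undetermined (ndim == -1)") // numpy-style unknown
    }
    arr.dispose()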
58 changes: 58 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.Base._
+
+object NumpyScope {
+  def setNumpyCompatible(isNpComp: Boolean): Boolean = {
+    val prev = new RefInt()
+    checkCall(_LIB.mxSetIsNumpyCompatible(if (isNpComp) 1 else 0, prev))
+    if (prev.value != 0) true else false
+  }
+
+  def isNumpyCompatible: Boolean = {
+    val curr = new RefInt
+    checkCall(_LIB.mxIsNumpyCompatible(curr))
+    if (curr.value != 0) true else false
+  }
+
+  def enableNumpyCompatible: NumpyScope = {
+    new NumpyScope(true)
+  }
+
+
+  def disableNumpyCompatible: NumpyScope = {
+    new NumpyScope(false)
+  }
+}
+
+class NumpyScope(var isCompatible: Boolean) {
+  private var prev: Boolean = false
+
+  def withScope[T](body: => T): T = {
+    prev = NumpyScope.setNumpyCompatible(isCompatible)
+    try {
+      body
+    } finally {
+      if (prev != isCompatible) {
+        NumpyScope.setNumpyCompatible(prev)
+      }
+    }
+  }
+}
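withScope follows the usual save-and-restore pattern: it records the previous global flag, runs the body, and puts the old value back in the finally block even if the body throws. Usage, mirroring the test suite below:

    import org.apache.mxnet.NumpyScope

    NumpyScope.enableNumpyCompatible.withScope {
      // inside the block, unknown dims are reported numpy-style (-1 rather than 0)
      assert(NumpyScope.isNumpyCompatible)
    }
    // the previous compatibility setting is restored here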
38 changes: 33 additions & 5 deletions scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -260,17 +260,45 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeResource
   def inferShape(keys: Array[String], indPtr: Array[Int], values: Array[Int])
     : (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = {
+    val res = inferShapeImpl(partial = false, keys, indPtr, values)
+    if (res._2 == null) {
+      val (argShapes, _, _) = inferShapeImpl(partial = true, keys, indPtr, values)
+      val argNames = listArguments()
+      val unknown = (argNames zip argShapes).map { case (name, shape) =>
+        val shapeIsNone = if (NumpyScope.isNumpyCompatible) {
+          shape == null || shape.toVector.contains(-1)
+        } else {
+          shape == null || shape.toVector.contains(0)
+        }
+        if (shapeIsNone) s"$name: $shape" else ""
+      }
+      logger.warn("Cannot decide shape for the following arguments. " +
+        "Consider providing them as input: \n\t{}",
+        unknown.filter(_ != "").mkString("\n\t"))
+    }
+    res
+  }
 
+  private def inferShapeImpl(partial: Boolean,
+                             keys: Array[String],
+                             indPtr: Array[Int],
+                             values: Array[Int])
+    : (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = {
     val argShapeData = ListBuffer.empty[Array[Int]]
     val outShapeData = ListBuffer.empty[Array[Int]]
     val auxShapeData = ListBuffer.empty[Array[Int]]
     val complete = new RefInt
 
-    checkCall(_LIB.mxSymbolInferShape(handle, indPtr.length - 1, keys, indPtr, values,
-      argShapeData, outShapeData, auxShapeData, complete))
+    if (partial) {
+      checkCall(_LIB.mxSymbolInferShapePartial(handle, indPtr.length - 1, keys, indPtr, values,
+        argShapeData, outShapeData, auxShapeData, complete))
+    } else {
+      checkCall(_LIB.mxSymbolInferShape(handle, indPtr.length - 1, keys, indPtr, values,
+        argShapeData, outShapeData, auxShapeData, complete))
+    }
     if (complete.value != 0) {
       (argShapeData.map(s => Shape(s)).toIndexedSeq,
-       outShapeData.map(s => Shape(s)).toIndexedSeq,
-       auxShapeData.map(s => Shape(s)).toIndexedSeq)
+        outShapeData.map(s => Shape(s)).toIndexedSeq,
+        auxShapeData.map(s => Shape(s)).toIndexedSeq)
     } else {
       (null, null, null)
     }
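The effect of the fallback: when full inference cannot complete, inferShape reruns via mxSymbolInferShapePartial and warns with the names of the arguments it could not decide, instead of failing silently. A hedged sketch (the symbol and shapes are illustrative; 0 marks an unknown dim in legacy mode, -1 under NumpyScope):

    import org.apache.mxnet._

    val data = Symbol.Variable("data")
    val net = Symbol.api.FullyConnected(data = Some(data), num_hidden = 10)

    // Fully specified input: inference completes and returns all three shape lists.
    val (argShapes, outShapes, _) = net.inferShape(Map("data" -> Shape(32, 100)))

    // Underspecified input: expected to return (null, null, null) and log a warning
    // listing the undecidable arguments (e.g. the FullyConnected weight).
    val (argShapes2, _, _) = net.inferShape(Map("data" -> Shape(32, 0)))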
34 changes: 34 additions & 0 deletions scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class NumpyScopeSuite extends FunSuite with BeforeAndAfterAll {
+  test("compatible") {
+    NumpyScope.enableNumpyCompatible.withScope {
+      assert(NumpyScope.isNumpyCompatible === true)
+    }
+  }
+
+  test("incompatible") {
+    NumpyScope.disableNumpyCompatible.withScope {
+      assert(NumpyScope.isNumpyCompatible === false)
+    }
+  }
+}
