diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala index 20b6ed9fc806..40fc0951e885 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala @@ -86,9 +86,10 @@ private[mxnet] class LibInfo { @native def mxNDArrayAt(handle: NDArrayHandle, idx: MXUint, out: NDArrayHandleRef): Int - @native def mxNDArrayReshape(handle: NDArrayHandle, + @native def mxNDArrayReshape64(handle: NDArrayHandle, nDim: Int, - dims: Array[Int], + dims: Array[Long], + reverse: Boolean, reshapeHandle: NDArrayHandleRef): Int @native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle, source: Array[MXFloat], diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 915e4c69de31..ab42265ae102 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -950,8 +950,19 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * @return a reshaped NDArray that shares memory with current one. */ def reshape(dims: Array[Int]): NDArray = { + reshape(dims.map(_.toLong)) + } + + /** + * Return a reshaped NDArray that shares memory with current one. + * @param dims New shape. + * @param reverse whether to inplace reshape + * @return a reshaped NDArray that shares memory with current one. + */ + def reshape(dims: Array[Long], reverse: Option[Boolean] = None): NDArray = { val reshapeHandle = new NDArrayHandleRef - checkCall(_LIB.mxNDArrayReshape(handle, dims.length, dims, reshapeHandle)) + checkCall(_LIB.mxNDArrayReshape64(handle, + dims.length, dims, reverse.getOrElse(false), reshapeHandle)) new NDArray(handle = reshapeHandle.value, writable = this.writable) } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 206094c15958..c2ef641f9c9a 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -878,14 +878,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("reshape") { - val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2)) + var arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2)) - val arr1 = arr.reshape(Array(2, 3)) + var arr1 = arr.reshape(Array(2, 3)) assert(arr1.shape === Shape(2, 3)) assert(arr1.toArray === Array(1f, 2f, 3f, 4f, 5f, 6f)) arr.set(1f) assert(arr1.toArray === Array(1f, 1f, 1f, 1f, 1f, 1f)) + + arr = NDArray.ones(1, 384, 1) + arr1 = arr.reshape(Array(0, -3)) + assert(arr1.shape === Shape(1, 384)) } test("dispose deps") { diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index ea6e9c8f5ba4..33e4cca99b3a 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -404,14 +404,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt return ret; } -JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape - (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim, jintArray dims, jobject reshapedHandle) { +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64 + (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim, + jlongArray dims, jboolean reverse, jobject reshapedHandle) { NDArrayHandle out; - jint *pdims = env->GetIntArrayElements(dims, NULL); - int ret = MXNDArrayReshape(reinterpret_cast(ndArrayPtr), ndim, - reinterpret_cast(pdims), &out); + jlong *pdims = env->GetLongArrayElements(dims, NULL); + int ret = MXNDArrayReshape64(reinterpret_cast(ndArrayPtr), ndim, + reinterpret_cast(pdims), reverse, &out); SetLongField(env, reshapedHandle, reinterpret_cast(out)); - env->ReleaseIntArrayElements(dims, pdims, 0); + env->ReleaseLongArrayElements(dims, pdims, 0); return ret; } diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h index 7e8e03de9124..b8a9b3b9e64f 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h @@ -161,11 +161,11 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt /* * Class: org_apache_mxnet_LibInfo - * Method: mxNDArrayReshape - * Signature: (JI[ILorg/apache/mxnet/Base/RefLong;)I + * Method: mxNDArrayReshape64 + * Signature: (JI[JZLorg/apache/mxnet/Base/RefLong;)I */ -JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape - (JNIEnv *, jobject, jlong, jint, jintArray, jobject); +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64 + (JNIEnv *, jobject, jlong, jint, jlongArray, jboolean, jobject); /* * Class: org_apache_mxnet_LibInfo