diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index 0ee749419655..ea3a2d68c9f4 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -81,6 +81,14 @@
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+ 2.22.0
+
+ false
+
+
org.scalastyle
scalastyle-maven-plugin
@@ -104,6 +112,12 @@
1.3.1-SNAPSHOT
provided
+
+ junit
+ junit
+ 4.11
+ test
+
commons-io
commons-io
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
index a8e4733608cc..acae2d6466db 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
@@ -16,7 +16,10 @@
*/
package org.apache.mxnet.javaapi
+import collection.JavaConverters._
+
class Context(val context: org.apache.mxnet.Context) {
+
val deviceTypeid: Int = context.deviceTypeid
def this(deviceTypeName: String, deviceId: Int = 0)
@@ -34,5 +37,13 @@ class Context(val context: org.apache.mxnet.Context) {
object Context {
implicit def fromContext(context: org.apache.mxnet.Context): Context = new Context(context)
+
implicit def toContext(jContext: Context): org.apache.mxnet.Context = jContext.context
+
+ val cpu:Context = org.apache.mxnet.Context.cpu()
+ val gpu:Context = org.apache.mxnet.Context.gpu()
+ val devtype2str = org.apache.mxnet.Context.devstr2type.asJava
+ val devstr2type = org.apache.mxnet.Context.devstr2type.asJava
+
+ def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx
}
\ No newline at end of file
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
index 391d076e0fd1..e669dd052b3a 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala
@@ -29,4 +29,6 @@ object DataDesc{
implicit def fromDataDesc(dataDesc: org.apache.mxnet.DataDesc): DataDesc = new DataDesc(dataDesc)
implicit def toDataDesc(dataDesc: DataDesc): org.apache.mxnet.DataDesc = dataDesc.dataDesc
+
+ def getBatchAxis(layout: String): Int = org.apache.mxnet.DataDesc.getBatchAxis(Some(layout));
}
\ No newline at end of file
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala
index 5dad83b82724..a9a31d9ba1e4 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala
@@ -37,7 +37,7 @@ class Shape(val shape: org.apache.mxnet.Shape) {
def head: Int = shape.head
def toArray: Array[Int] = shape.toArray
- def toVector: Vector[Int] = shape.toVector
+ def toVector: java.util.List[Int] = shape.toVector.asJava
override def toString(): String = shape.toString
override def equals(o: Any): Boolean = shape.equals(o)
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaContextTest.java b/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaContextTest.java
new file mode 100644
index 000000000000..b00346cdd972
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaContextTest.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class JavaContextTest {
+
+ @Test
+ public void testCPU() {
+ Context.cpu();
+ }
+
+ @Test
+ public void testDefault() {
+ Context.defaultCtx();
+ }
+
+ @Test
+ public void testConstructor() {
+ new Context("cpu", 0);
+ }
+}
\ No newline at end of file
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaShapeTest.java b/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaShapeTest.java
new file mode 100644
index 000000000000..38ea24783efa
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaShapeTest.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import org.junit.Test;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+public class JavaShapeTest {
+ @Test
+ public void testArrayConstructor()
+ {
+ new Shape(new int[] {3, 4, 5});
+ }
+
+ @Test
+ public void testListConstructor()
+ {
+ ArrayList arrList = new ArrayList();
+ arrList.add(3);
+ arrList.add(4);
+ arrList.add(5);
+ new Shape(arrList);
+ }
+
+ @Test
+ public void testApply()
+ {
+ Shape jS = new Shape(new int[] {3, 4, 5});
+ assertEquals(jS.apply(1), 4);
+ }
+
+ @Test
+ public void testGet()
+ {
+ Shape jS = new Shape(new int[] {3, 4, 5});
+ assertEquals(jS.get(1), 4);
+ }
+
+ @Test
+ public void testSize()
+ {
+ Shape jS = new Shape(new int[] {3, 4, 5});
+ assertEquals(jS.size(), 3);
+ }
+
+ @Test
+ public void testLength()
+ {
+ Shape jS = new Shape(new int[] {3, 4, 5});
+ assertEquals(jS.length(), 3);
+ }
+
+ @Test
+ public void testDrop()
+ {
+ Shape jS = new Shape(new int[] {3, 4, 5});
+ ArrayList l = new ArrayList();
+ l.add(4);
+ l.add(5);
+ assertTrue(jS.drop(1).toVector().equals(l));
+ }
+
+ @Test
+ public void testSlice()
+ {
+ Shape jS = new Shape(new int[] {3, 4, 5});
+ ArrayList l = new ArrayList();
+ l.add(4);
+ assertTrue(jS.slice(1,2).toVector().equals(l));
+ }
+
+ @Test
+ public void testProduct()
+ {
+ Shape jS = new Shape(new int[] {3, 4, 5});
+ assertEquals(jS.product(), 60);
+ }
+
+ @Test
+ public void testHead()
+ {
+ Shape jS = new Shape(new int[] {3, 4, 5});
+ assertEquals(jS.head(), 3);
+ }
+
+ @Test
+ public void testToArray()
+ {
+ Shape jS = new Shape(new int[] {3, 4, 5});
+ assertTrue(Arrays.equals(jS.toArray(), new int[] {3,4,5}));
+ }
+
+ @Test
+ public void testToVector()
+ {
+ Shape jS = new Shape(new int[] {3, 4, 5});
+ ArrayList l = new ArrayList();
+ l.add(3);
+ l.add(4);
+ l.add(5);
+ assertTrue(jS.toVector().equals(l));
+ }
+}
\ No newline at end of file