From ac380c84b36f03492793af44dae423e39d0e2ae3 Mon Sep 17 00:00:00 2001 From: Andrew Ayres Date: Fri, 5 Oct 2018 14:50:53 -0700 Subject: [PATCH] Adding basic unit tests, bug fixes, and expanding some Java API classes (#4) * Adding basic unit tests, bug fixes, and expanding some Java API classes * Moved pom skipTests change into core --- scala-package/core/pom.xml | 14 ++ .../org/apache/mxnet/javaapi/Context.scala | 11 ++ .../scala/org/apache/mxnet/javaapi/IO.scala | 2 + .../org/apache/mxnet/javaapi/Shape.scala | 2 +- .../mxnet/api/java/JavaContextTest.java | 40 ++++++ .../apache/mxnet/api/java/JavaShapeTest.java | 121 ++++++++++++++++++ 6 files changed, 189 insertions(+), 1 deletion(-) create mode 100644 scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaContextTest.java create mode 100644 scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaShapeTest.java diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index 0ee749419655..ea3a2d68c9f4 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -81,6 +81,14 @@ + + org.apache.maven.plugins + maven-surefire-plugin + 2.22.0 + + false + + org.scalastyle scalastyle-maven-plugin @@ -104,6 +112,12 @@ 1.3.1-SNAPSHOT provided + + junit + junit + 4.11 + test + commons-io commons-io diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala index a8e4733608cc..acae2d6466db 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala @@ -16,7 +16,10 @@ */ package org.apache.mxnet.javaapi +import collection.JavaConverters._ + class Context(val context: org.apache.mxnet.Context) { + val deviceTypeid: Int = context.deviceTypeid def this(deviceTypeName: String, deviceId: Int = 0) @@ -34,5 +37,13 @@ class Context(val context: org.apache.mxnet.Context) { object Context { implicit def fromContext(context: org.apache.mxnet.Context): Context = new Context(context) + implicit def toContext(jContext: Context): org.apache.mxnet.Context = jContext.context + + val cpu:Context = org.apache.mxnet.Context.cpu() + val gpu:Context = org.apache.mxnet.Context.gpu() + val devtype2str = org.apache.mxnet.Context.devstr2type.asJava + val devstr2type = org.apache.mxnet.Context.devstr2type.asJava + + def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx } \ No newline at end of file diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala index 391d076e0fd1..e669dd052b3a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/IO.scala @@ -29,4 +29,6 @@ object DataDesc{ implicit def fromDataDesc(dataDesc: org.apache.mxnet.DataDesc): DataDesc = new DataDesc(dataDesc) implicit def toDataDesc(dataDesc: DataDesc): org.apache.mxnet.DataDesc = dataDesc.dataDesc + + def getBatchAxis(layout: String): Int = org.apache.mxnet.DataDesc.getBatchAxis(Some(layout)); } \ No newline at end of file diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala index 5dad83b82724..a9a31d9ba1e4 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Shape.scala @@ -37,7 +37,7 @@ class Shape(val shape: org.apache.mxnet.Shape) { def head: Int = shape.head def toArray: Array[Int] = shape.toArray - def toVector: Vector[Int] = shape.toVector + def toVector: java.util.List[Int] = shape.toVector.asJava override def toString(): String = shape.toString override def equals(o: Any): Boolean = shape.equals(o) diff --git a/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaContextTest.java b/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaContextTest.java new file mode 100644 index 000000000000..b00346cdd972 --- /dev/null +++ b/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaContextTest.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi; + +import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class JavaContextTest { + + @Test + public void testCPU() { + Context.cpu(); + } + + @Test + public void testDefault() { + Context.defaultCtx(); + } + + @Test + public void testConstructor() { + new Context("cpu", 0); + } +} \ No newline at end of file diff --git a/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaShapeTest.java b/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaShapeTest.java new file mode 100644 index 000000000000..38ea24783efa --- /dev/null +++ b/scala-package/core/src/test/java/org/apache/mxnet/api/java/JavaShapeTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import org.junit.Test; +import java.util.ArrayList; +import java.util.Arrays; + +public class JavaShapeTest { + @Test + public void testArrayConstructor() + { + new Shape(new int[] {3, 4, 5}); + } + + @Test + public void testListConstructor() + { + ArrayList arrList = new ArrayList(); + arrList.add(3); + arrList.add(4); + arrList.add(5); + new Shape(arrList); + } + + @Test + public void testApply() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.apply(1), 4); + } + + @Test + public void testGet() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.get(1), 4); + } + + @Test + public void testSize() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.size(), 3); + } + + @Test + public void testLength() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.length(), 3); + } + + @Test + public void testDrop() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + ArrayList l = new ArrayList(); + l.add(4); + l.add(5); + assertTrue(jS.drop(1).toVector().equals(l)); + } + + @Test + public void testSlice() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + ArrayList l = new ArrayList(); + l.add(4); + assertTrue(jS.slice(1,2).toVector().equals(l)); + } + + @Test + public void testProduct() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.product(), 60); + } + + @Test + public void testHead() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertEquals(jS.head(), 3); + } + + @Test + public void testToArray() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + assertTrue(Arrays.equals(jS.toArray(), new int[] {3,4,5})); + } + + @Test + public void testToVector() + { + Shape jS = new Shape(new int[] {3, 4, 5}); + ArrayList l = new ArrayList(); + l.add(3); + l.add(4); + l.add(5); + assertTrue(jS.toVector().equals(l)); + } +} \ No newline at end of file