Skip to content

Commit

Permalink
Adding basic unit tests, bug fixes, and expanding some Java API class…
Browse files Browse the repository at this point in the history
…es (apache#4)

* Adding basic unit tests, bug fixes, and expanding some Java API classes

* Moved pom skipTests change into core
  • Loading branch information
andrewfayres committed Oct 5, 2018
1 parent 3861a90 commit ac380c8
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 1 deletion.
14 changes: 14 additions & 0 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@
</argLine>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.0</version>
<configuration>
<skipTests>false</skipTests>
</configuration>
</plugin>
<plugin>
<groupId>org.scalastyle</groupId>
<artifactId>scalastyle-maven-plugin</artifactId>
Expand All @@ -104,6 +112,12 @@
<version>1.3.1-SNAPSHOT</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
*/
package org.apache.mxnet.javaapi

import collection.JavaConverters._

class Context(val context: org.apache.mxnet.Context) {

val deviceTypeid: Int = context.deviceTypeid

def this(deviceTypeName: String, deviceId: Int = 0)
Expand All @@ -34,5 +37,13 @@ class Context(val context: org.apache.mxnet.Context) {
object Context {
implicit def fromContext(context: org.apache.mxnet.Context): Context = new Context(context)


implicit def toContext(jContext: Context): org.apache.mxnet.Context = jContext.context

val cpu:Context = org.apache.mxnet.Context.cpu()
val gpu:Context = org.apache.mxnet.Context.gpu()
val devtype2str = org.apache.mxnet.Context.devstr2type.asJava
val devstr2type = org.apache.mxnet.Context.devstr2type.asJava

def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ object DataDesc{
implicit def fromDataDesc(dataDesc: org.apache.mxnet.DataDesc): DataDesc = new DataDesc(dataDesc)

implicit def toDataDesc(dataDesc: DataDesc): org.apache.mxnet.DataDesc = dataDesc.dataDesc

def getBatchAxis(layout: String): Int = org.apache.mxnet.DataDesc.getBatchAxis(Some(layout));
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Shape(val shape: org.apache.mxnet.Shape) {
def head: Int = shape.head

def toArray: Array[Int] = shape.toArray
def toVector: Vector[Int] = shape.toVector
def toVector: java.util.List[Int] = shape.toVector.asJava

override def toString(): String = shape.toString
override def equals(o: Any): Boolean = shape.equals(o)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi;

import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class JavaContextTest {

@Test
public void testCPU() {
Context.cpu();
}

@Test
public void testDefault() {
Context.defaultCtx();
}

@Test
public void testConstructor() {
new Context("cpu", 0);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;

public class JavaShapeTest {
@Test
public void testArrayConstructor()
{
new Shape(new int[] {3, 4, 5});
}

@Test
public void testListConstructor()
{
ArrayList<Integer> arrList = new ArrayList<Integer>();
arrList.add(3);
arrList.add(4);
arrList.add(5);
new Shape(arrList);
}

@Test
public void testApply()
{
Shape jS = new Shape(new int[] {3, 4, 5});
assertEquals(jS.apply(1), 4);
}

@Test
public void testGet()
{
Shape jS = new Shape(new int[] {3, 4, 5});
assertEquals(jS.get(1), 4);
}

@Test
public void testSize()
{
Shape jS = new Shape(new int[] {3, 4, 5});
assertEquals(jS.size(), 3);
}

@Test
public void testLength()
{
Shape jS = new Shape(new int[] {3, 4, 5});
assertEquals(jS.length(), 3);
}

@Test
public void testDrop()
{
Shape jS = new Shape(new int[] {3, 4, 5});
ArrayList<Integer> l = new ArrayList<Integer>();
l.add(4);
l.add(5);
assertTrue(jS.drop(1).toVector().equals(l));
}

@Test
public void testSlice()
{
Shape jS = new Shape(new int[] {3, 4, 5});
ArrayList<Integer> l = new ArrayList<Integer>();
l.add(4);
assertTrue(jS.slice(1,2).toVector().equals(l));
}

@Test
public void testProduct()
{
Shape jS = new Shape(new int[] {3, 4, 5});
assertEquals(jS.product(), 60);
}

@Test
public void testHead()
{
Shape jS = new Shape(new int[] {3, 4, 5});
assertEquals(jS.head(), 3);
}

@Test
public void testToArray()
{
Shape jS = new Shape(new int[] {3, 4, 5});
assertTrue(Arrays.equals(jS.toArray(), new int[] {3,4,5}));
}

@Test
public void testToVector()
{
Shape jS = new Shape(new int[] {3, 4, 5});
ArrayList<Integer> l = new ArrayList<Integer>();
l.add(3);
l.add(4);
l.add(5);
assertTrue(jS.toVector().equals(l));
}
}

0 comments on commit ac380c8

Please sign in to comment.