From 2302420c2c68f3984015259f3e25a31e01c43559 Mon Sep 17 00:00:00 2001 From: Piyush Ghai Date: Thu, 10 Jan 2019 15:41:45 -0800 Subject: [PATCH] [MXNET-1263] Unit Tests for Java Predictor and Object Detector APIs (#13794) * Added unit tests for Predictor API in Java * Added unit tests for ObjectDetectorOutput * Added unit tests for ObjectDetector API in Java * Addressed PR comments * Added Maven SureFire plugin to run the Java UTs * Pom file clean up -- moved surefire plugin to parent pom.xml * Renamed skipTests to SkipJavaTests --- scala-package/core/pom.xml | 14 +-- scala-package/infer/pom.xml | 12 ++ .../javaapi/ObjectDetectorOutputTest.java | 59 ++++++++++ .../infer/javaapi/ObjectDetectorTest.java | 106 ++++++++++++++++++ .../mxnet/infer/javaapi/PredictorTest.java | 100 +++++++++++++++++ scala-package/pom.xml | 7 +- 6 files changed, 287 insertions(+), 11 deletions(-) create mode 100644 scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java create mode 100644 scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java create mode 100644 scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index 8a54890ac946..7264c39e84a0 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -13,6 +13,10 @@ mxnet-core MXNet Scala Package - Core + + false + + @@ -115,16 +119,6 @@ - - org.apache.maven.plugins - maven-surefire-plugin - 2.22.0 - - - -Djava.library.path=${project.parent.basedir}/native/target - - - org.scalastyle scalastyle-maven-plugin diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml index 68b76c8b65ad..13ceebb83cd9 100644 --- a/scala-package/infer/pom.xml +++ b/scala-package/infer/pom.xml @@ -10,6 +10,10 @@ ../pom.xml + + false + + mxnet-infer MXNet Scala Package - Inference @@ -60,5 +64,13 @@ 1.10.19 test + + + junit + junit + 4.11 + test + + diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java new file mode 100644 index 000000000000..04041fcda9bf --- /dev/null +++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.infer.javaapi; + +import org.junit.Assert; +import org.junit.Test; + +public class ObjectDetectorOutputTest { + + private String predictedClassName = "lion"; + + private float delta = 0.00001f; + + @Test + public void testConstructor() { + + float[] arr = new float[]{0f, 1f, 2f, 3f, 4f}; + + ObjectDetectorOutput odOutput = new ObjectDetectorOutput(predictedClassName, arr); + + Assert.assertEquals(odOutput.getClassName(), predictedClassName); + Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta); + + } + + @Test (expected = ArrayIndexOutOfBoundsException.class) + public void testIncompleteArgsConstructor() { + + float[] arr = new float[]{0f, 1f}; + + ObjectDetectorOutput odOutput = new ObjectDetectorOutput(predictedClassName, arr); + + Assert.assertEquals(odOutput.getClassName(), predictedClassName); + Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta); + + // This is where exception will be thrown + odOutput.getXMax(); + } +} diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java new file mode 100644 index 000000000000..a5e64911d141 --- /dev/null +++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.infer.javaapi; + +import org.apache.mxnet.Layout; +import org.apache.mxnet.javaapi.DType; +import org.apache.mxnet.javaapi.DataDesc; +import org.apache.mxnet.javaapi.NDArray; +import org.apache.mxnet.javaapi.Shape; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import java.awt.image.BufferedImage; +import java.util.ArrayList; +import java.util.List; + +public class ObjectDetectorTest { + + private List inputDesc; + private BufferedImage inputImage; + + private List> expectedResult; + + private ObjectDetector objectDetector; + + private int batchSize = 1; + + private int channels = 3; + + private int imageHeight = 512; + + private int imageWidth = 512; + + private String dataName = "data"; + + private int topK = 5; + + private String predictedClassName = "lion"; // Random string + + private Shape getTestShape() { + + return new Shape(new int[] {batchSize, channels, imageHeight, imageWidth}); + } + + @Before + public void setUp() { + + inputDesc = new ArrayList<>(); + inputDesc.add(new DataDesc(dataName, getTestShape(), DType.Float32(), Layout.NCHW())); + inputImage = new BufferedImage(imageWidth, imageHeight, BufferedImage.TYPE_INT_RGB); + objectDetector = Mockito.mock(ObjectDetector.class); + expectedResult = new ArrayList<>(); + expectedResult.add(new ArrayList()); + expectedResult.get(0).add(new ObjectDetectorOutput(predictedClassName, new float[]{})); + } + + @Test + public void testObjectDetectorWithInputImage() { + + Mockito.when(objectDetector.imageObjectDetect(inputImage, topK)).thenReturn(expectedResult); + List> actualResult = objectDetector.imageObjectDetect(inputImage, topK); + Mockito.verify(objectDetector, Mockito.times(1)).imageObjectDetect(inputImage, topK); + Assert.assertEquals(expectedResult, actualResult); + } + + + @Test + public void testObjectDetectorWithBatchImage() { + + List batchImage = new ArrayList<>(); + batchImage.add(inputImage); + Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, topK)).thenReturn(expectedResult); + List> actualResult = objectDetector.imageBatchObjectDetect(batchImage, topK); + Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, topK); + Assert.assertEquals(expectedResult, actualResult); + } + + @Test + public void testObjectDetectorWithNDArrayInput() { + + NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape()); + List inputL = new ArrayList<>(); + inputL.add(inputArr); + Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(expectedResult); + List> actualResult = objectDetector.objectDetectWithNDArray(inputL, topK); + Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK); + Assert.assertEquals(expectedResult, actualResult); + } +} diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java new file mode 100644 index 000000000000..e7a6c9652346 --- /dev/null +++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.infer.javaapi; + +import org.apache.mxnet.javaapi.Context; +import org.apache.mxnet.javaapi.NDArray; +import org.apache.mxnet.javaapi.Shape; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class PredictorTest { + + Predictor mockPredictor; + + @Before + public void setUp() { + mockPredictor = Mockito.mock(Predictor.class); + } + + @Test + public void testPredictWithFloatArray() { + + float tmp[][] = new float[1][224]; + for (int x = 0; x < 1; x++) { + for (int y = 0; y < 224; y++) + tmp[x][y] = (int) (Math.random() * 10); + } + + float [][] expectedResult = new float[][] {{1f, 2f}}; + Mockito.when(mockPredictor.predict(tmp)).thenReturn(expectedResult); + float[][] actualResult = mockPredictor.predict(tmp); + + Mockito.verify(mockPredictor, Mockito.times(1)).predict(tmp); + Assert.assertArrayEquals(expectedResult, actualResult); + } + + @Test + public void testPredictWithNDArray() { + + float[] tmpArr = new float[224]; + for (int y = 0; y < 224; y++) + tmpArr[y] = (int) (Math.random() * 10); + + NDArray arr = new org.apache.mxnet.javaapi.NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0)); + + List inputList = new ArrayList<>(); + inputList.add(arr); + + NDArray expected = new NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0)); + List expectedResult = new ArrayList<>(); + expectedResult.add(expected); + + Mockito.when(mockPredictor.predictWithNDArray(inputList)).thenReturn(expectedResult); + + List actualOutput = mockPredictor.predictWithNDArray(inputList); + + Mockito.verify(mockPredictor, Mockito.times(1)).predictWithNDArray(inputList); + + Assert.assertEquals(expectedResult, actualOutput); + } + + @Test + public void testPredictWithListOfFloatsAsInput() { + List> input = new ArrayList<>(); + + input.add(Arrays.asList(new Float[] {1f, 2f})); + + List> expectedOutput = new ArrayList<>(input); + + Mockito.when(mockPredictor.predict(input)).thenReturn(expectedOutput); + + List> actualOutput = mockPredictor.predict(input); + + Mockito.verify(mockPredictor, Mockito.times(1)).predict(input); + + Assert.assertEquals(expectedOutput, actualOutput); + + } +} \ No newline at end of file diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 1ea898013951..6665e953dcd1 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -42,6 +42,7 @@ g++ $ ${project.basedir}/.. + true pom @@ -228,8 +229,12 @@ org.apache.maven.plugins maven-surefire-plugin - 2.19 + 2.22.0 + ${skipJavaTests} + + -Djava.library.path=${project.parent.basedir}/native/target + false