diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index 8a54890ac946..7264c39e84a0 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -13,6 +13,10 @@
mxnet-core
MXNet Scala Package - Core
+
+ false
+
+
@@ -115,16 +119,6 @@
-
- org.apache.maven.plugins
- maven-surefire-plugin
- 2.22.0
-
-
- -Djava.library.path=${project.parent.basedir}/native/target
-
-
-
org.scalastyle
scalastyle-maven-plugin
diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml
index 68b76c8b65ad..13ceebb83cd9 100644
--- a/scala-package/infer/pom.xml
+++ b/scala-package/infer/pom.xml
@@ -10,6 +10,10 @@
../pom.xml
+
+ false
+
+
mxnet-infer
MXNet Scala Package - Inference
@@ -60,5 +64,13 @@
1.10.19
test
+
+
+ junit
+ junit
+ 4.11
+ test
+
+
diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java
new file mode 100644
index 000000000000..04041fcda9bf
--- /dev/null
+++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ObjectDetectorOutputTest {
+
+ private String predictedClassName = "lion";
+
+ private float delta = 0.00001f;
+
+ @Test
+ public void testConstructor() {
+
+ float[] arr = new float[]{0f, 1f, 2f, 3f, 4f};
+
+ ObjectDetectorOutput odOutput = new ObjectDetectorOutput(predictedClassName, arr);
+
+ Assert.assertEquals(odOutput.getClassName(), predictedClassName);
+ Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
+ Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);
+ Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta);
+ Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta);
+ Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta);
+
+ }
+
+ @Test (expected = ArrayIndexOutOfBoundsException.class)
+ public void testIncompleteArgsConstructor() {
+
+ float[] arr = new float[]{0f, 1f};
+
+ ObjectDetectorOutput odOutput = new ObjectDetectorOutput(predictedClassName, arr);
+
+ Assert.assertEquals(odOutput.getClassName(), predictedClassName);
+ Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
+ Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);
+
+ // This is where exception will be thrown
+ odOutput.getXMax();
+ }
+}
diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java
new file mode 100644
index 000000000000..a5e64911d141
--- /dev/null
+++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi;
+
+import org.apache.mxnet.Layout;
+import org.apache.mxnet.javaapi.DType;
+import org.apache.mxnet.javaapi.DataDesc;
+import org.apache.mxnet.javaapi.NDArray;
+import org.apache.mxnet.javaapi.Shape;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.awt.image.BufferedImage;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ObjectDetectorTest {
+
+ private List inputDesc;
+ private BufferedImage inputImage;
+
+ private List> expectedResult;
+
+ private ObjectDetector objectDetector;
+
+ private int batchSize = 1;
+
+ private int channels = 3;
+
+ private int imageHeight = 512;
+
+ private int imageWidth = 512;
+
+ private String dataName = "data";
+
+ private int topK = 5;
+
+ private String predictedClassName = "lion"; // Random string
+
+ private Shape getTestShape() {
+
+ return new Shape(new int[] {batchSize, channels, imageHeight, imageWidth});
+ }
+
+ @Before
+ public void setUp() {
+
+ inputDesc = new ArrayList<>();
+ inputDesc.add(new DataDesc(dataName, getTestShape(), DType.Float32(), Layout.NCHW()));
+ inputImage = new BufferedImage(imageWidth, imageHeight, BufferedImage.TYPE_INT_RGB);
+ objectDetector = Mockito.mock(ObjectDetector.class);
+ expectedResult = new ArrayList<>();
+ expectedResult.add(new ArrayList());
+ expectedResult.get(0).add(new ObjectDetectorOutput(predictedClassName, new float[]{}));
+ }
+
+ @Test
+ public void testObjectDetectorWithInputImage() {
+
+ Mockito.when(objectDetector.imageObjectDetect(inputImage, topK)).thenReturn(expectedResult);
+ List> actualResult = objectDetector.imageObjectDetect(inputImage, topK);
+ Mockito.verify(objectDetector, Mockito.times(1)).imageObjectDetect(inputImage, topK);
+ Assert.assertEquals(expectedResult, actualResult);
+ }
+
+
+ @Test
+ public void testObjectDetectorWithBatchImage() {
+
+ List batchImage = new ArrayList<>();
+ batchImage.add(inputImage);
+ Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, topK)).thenReturn(expectedResult);
+ List> actualResult = objectDetector.imageBatchObjectDetect(batchImage, topK);
+ Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, topK);
+ Assert.assertEquals(expectedResult, actualResult);
+ }
+
+ @Test
+ public void testObjectDetectorWithNDArrayInput() {
+
+ NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape());
+ List inputL = new ArrayList<>();
+ inputL.add(inputArr);
+ Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(expectedResult);
+ List> actualResult = objectDetector.objectDetectWithNDArray(inputL, topK);
+ Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
+ Assert.assertEquals(expectedResult, actualResult);
+ }
+}
diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java
new file mode 100644
index 000000000000..e7a6c9652346
--- /dev/null
+++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/PredictorTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.infer.javaapi;
+
+import org.apache.mxnet.javaapi.Context;
+import org.apache.mxnet.javaapi.NDArray;
+import org.apache.mxnet.javaapi.Shape;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class PredictorTest {
+
+ Predictor mockPredictor;
+
+ @Before
+ public void setUp() {
+ mockPredictor = Mockito.mock(Predictor.class);
+ }
+
+ @Test
+ public void testPredictWithFloatArray() {
+
+ float tmp[][] = new float[1][224];
+ for (int x = 0; x < 1; x++) {
+ for (int y = 0; y < 224; y++)
+ tmp[x][y] = (int) (Math.random() * 10);
+ }
+
+ float [][] expectedResult = new float[][] {{1f, 2f}};
+ Mockito.when(mockPredictor.predict(tmp)).thenReturn(expectedResult);
+ float[][] actualResult = mockPredictor.predict(tmp);
+
+ Mockito.verify(mockPredictor, Mockito.times(1)).predict(tmp);
+ Assert.assertArrayEquals(expectedResult, actualResult);
+ }
+
+ @Test
+ public void testPredictWithNDArray() {
+
+ float[] tmpArr = new float[224];
+ for (int y = 0; y < 224; y++)
+ tmpArr[y] = (int) (Math.random() * 10);
+
+ NDArray arr = new org.apache.mxnet.javaapi.NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));
+
+ List inputList = new ArrayList<>();
+ inputList.add(arr);
+
+ NDArray expected = new NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));
+ List expectedResult = new ArrayList<>();
+ expectedResult.add(expected);
+
+ Mockito.when(mockPredictor.predictWithNDArray(inputList)).thenReturn(expectedResult);
+
+ List actualOutput = mockPredictor.predictWithNDArray(inputList);
+
+ Mockito.verify(mockPredictor, Mockito.times(1)).predictWithNDArray(inputList);
+
+ Assert.assertEquals(expectedResult, actualOutput);
+ }
+
+ @Test
+ public void testPredictWithListOfFloatsAsInput() {
+ List> input = new ArrayList<>();
+
+ input.add(Arrays.asList(new Float[] {1f, 2f}));
+
+ List> expectedOutput = new ArrayList<>(input);
+
+ Mockito.when(mockPredictor.predict(input)).thenReturn(expectedOutput);
+
+ List> actualOutput = mockPredictor.predict(input);
+
+ Mockito.verify(mockPredictor, Mockito.times(1)).predict(input);
+
+ Assert.assertEquals(expectedOutput, actualOutput);
+
+ }
+}
\ No newline at end of file
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 1ea898013951..6665e953dcd1 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -42,6 +42,7 @@
g++
$
${project.basedir}/..
+ true
pom
@@ -228,8 +229,12 @@
org.apache.maven.plugins
maven-surefire-plugin
- 2.19
+ 2.22.0
+ ${skipJavaTests}
+
+ -Djava.library.path=${project.parent.basedir}/native/target
+
false