Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1293] Adding Iterables instead of List to method signature for…
Browse files Browse the repository at this point in the history
… infer APIs in Java (#13977)

* Added Iterables as input type instead of List in Predictor for Java

* Added Iterables to ObjectDetector API

* Added tests for Predictor API

* Added tests for ObjectDetector
  • Loading branch information
piyushghai authored and lanking520 committed Jan 24, 2019
1 parent 4700b40 commit 24412df
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ import scala.language.implicitConversions
*/
class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.ObjectDetector){

def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
java.util.List[Context], epoch: Int)
def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc], contexts:
java.lang.Iterable[Context], epoch: Int)
= this {
val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
.asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
Expand Down Expand Up @@ -79,7 +79,7 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj
* @return List of list of tuples of
* (class, [probability, xmin, ymin, xmax, ymax])
*/
def objectDetectWithNDArray(input: java.util.List[NDArray], topK: Int):
def objectDetectWithNDArray(input: java.lang.Iterable[NDArray], topK: Int):
java.util.List[java.util.List[ObjectDetectorOutput]] = {
val ret = objDetector.objectDetectWithNDArray(convert(input.asScala.toIndexedSeq), Some(topK))
(ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
Expand All @@ -92,7 +92,7 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj
* @param topK Number of result elements to return, sorted by probability
* @return List of list of tuples of (class, probability)
*/
def imageBatchObjectDetect(inputBatch: java.util.List[BufferedImage], topK: Int):
def imageBatchObjectDetect(inputBatch: java.lang.Iterable[BufferedImage], topK: Int):
java.util.List[java.util.List[ObjectDetectorOutput]] = {
val ret = objDetector.imageBatchObjectDetect(inputBatch.asScala, Some(topK))
(ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
Expand Down Expand Up @@ -122,7 +122,7 @@ object ObjectDetector {
org.apache.mxnet.infer.ImageClassifier.bufferedImageToPixels(resizedImage, inputImageShape)
}

def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = {
def loadInputBatch(inputImagePaths: java.lang.Iterable[String]): java.util.List[BufferedImage] = {
org.apache.mxnet.infer.ImageClassifier
.loadInputBatch(inputImagePaths.asScala.toList).toList.asJava
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ import scala.collection.JavaConverters._

// JavaDoc description of class to be updated in https://issues.apache.org/jira/browse/MXNET-1178
class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor){
def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc],
contexts: java.util.List[Context], epoch: Int)
def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc],
contexts: java.lang.Iterable[Context], epoch: Int)
= this {
val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
.asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
Expand Down Expand Up @@ -97,10 +97,10 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
}

/**
* Takes input as List of one dimensional arrays and creates the NDArray needed for inference
* Takes input as List of one dimensional iterables and creates the NDArray needed for inference
* The array will be reshaped based on the input descriptors.
*
* @param input: A List of a one-dimensional array.
* @param input: A List of a one-dimensional iterables of DType Float.
An extra List is needed for when the model has more than one input.
* @return Indexed sequence array of outputs
*/
Expand All @@ -118,10 +118,10 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
* This method is useful when the input is a batch of data
* Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
*
* @param input List of NDArrays
* @param input Iterable of NDArrays
* @return Output of predictions as NDArrays
*/
def predictWithNDArray(input: java.util.List[NDArray]):
def predictWithNDArray(input: java.lang.Iterable[NDArray]):
java.util.List[NDArray] = {
val ret = predictor.predictWithNDArray(convert(JavaConverters
.asScalaIteratorConverter(input.iterator).asScala.toIndexedSeq))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@

import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class ObjectDetectorTest {

Expand Down Expand Up @@ -92,6 +94,17 @@ public void testObjectDetectorWithBatchImage() {
Assert.assertEquals(expectedResult, actualResult);
}

@Test
public void testObjectDetectorWithIterableOfBatchImage() {

Set<BufferedImage> batchImage = new HashSet<>();
batchImage.add(inputImage);
Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, topK)).thenReturn(expectedResult);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.imageBatchObjectDetect(batchImage, topK);
Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, topK);
Assert.assertEquals(expectedResult, actualResult);
}

@Test
public void testObjectDetectorWithNDArrayInput() {

Expand All @@ -103,4 +116,16 @@ public void testObjectDetectorWithNDArrayInput() {
Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
Assert.assertEquals(expectedResult, actualResult);
}

@Test
public void testObjectDetectorWithIterableOfNDArrayInput() {

NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape());
Set<NDArray> inputL = new HashSet<>();
inputL.add(inputArr);
Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(expectedResult);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.objectDetectWithNDArray(inputL, topK);
Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
Assert.assertEquals(expectedResult, actualResult);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
import org.junit.Test;
import org.mockito.Mockito;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;

public class PredictorTest {

Expand Down Expand Up @@ -80,6 +78,31 @@ public void testPredictWithNDArray() {
Assert.assertEquals(expectedResult, actualOutput);
}

@Test
public void testPredictWithIterablesNDArray() {

float[] tmpArr = new float[224];
for (int y = 0; y < 224; y++)
tmpArr[y] = (int) (Math.random() * 10);

NDArray arr = new org.apache.mxnet.javaapi.NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));

Set<NDArray> inputSet = new HashSet<>();
inputSet.add(arr);

NDArray expected = new NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));
List<NDArray> expectedResult = new ArrayList<>();
expectedResult.add(expected);

Mockito.when(mockPredictor.predictWithNDArray(inputSet)).thenReturn(expectedResult);

List<NDArray> actualOutput = mockPredictor.predictWithNDArray(inputSet);

Mockito.verify(mockPredictor, Mockito.times(1)).predictWithNDArray(inputSet);

Assert.assertEquals(expectedResult, actualOutput);
}

@Test
public void testPredictWithListOfFloatsAsInput() {
List<List<Float>> input = new ArrayList<>();
Expand Down

0 comments on commit 24412df

Please sign in to comment.