Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1293] Adding Iterables instead of List to method signature for infer APIs in Java #13977

Merged
merged 4 commits into from
Jan 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ import scala.language.implicitConversions
*/
class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.ObjectDetector){

def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc], contexts:
java.util.List[Context], epoch: Int)
def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc], contexts:
java.lang.Iterable[Context], epoch: Int)
= this {
val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
.asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
Expand Down Expand Up @@ -79,7 +79,7 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj
* @return List of list of tuples of
* (class, [probability, xmin, ymin, xmax, ymax])
*/
def objectDetectWithNDArray(input: java.util.List[NDArray], topK: Int):
def objectDetectWithNDArray(input: java.lang.Iterable[NDArray], topK: Int):
java.util.List[java.util.List[ObjectDetectorOutput]] = {
val ret = objDetector.objectDetectWithNDArray(convert(input.asScala.toIndexedSeq), Some(topK))
(ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
Expand All @@ -92,7 +92,7 @@ class ObjectDetector private[mxnet] (val objDetector: org.apache.mxnet.infer.Obj
* @param topK Number of result elements to return, sorted by probability
* @return List of list of tuples of (class, probability)
*/
def imageBatchObjectDetect(inputBatch: java.util.List[BufferedImage], topK: Int):
def imageBatchObjectDetect(inputBatch: java.lang.Iterable[BufferedImage], topK: Int):
java.util.List[java.util.List[ObjectDetectorOutput]] = {
val ret = objDetector.imageBatchObjectDetect(inputBatch.asScala, Some(topK))
(ret map {a => (a map {e => new ObjectDetectorOutput(e._1, e._2)}).asJava}).asJava
Expand Down Expand Up @@ -122,7 +122,7 @@ object ObjectDetector {
org.apache.mxnet.infer.ImageClassifier.bufferedImageToPixels(resizedImage, inputImageShape)
}

def loadInputBatch(inputImagePaths: java.util.List[String]): java.util.List[BufferedImage] = {
def loadInputBatch(inputImagePaths: java.lang.Iterable[String]): java.util.List[BufferedImage] = {
org.apache.mxnet.infer.ImageClassifier
.loadInputBatch(inputImagePaths.asScala.toList).toList.asJava
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ import scala.collection.JavaConverters._

// JavaDoc description of class to be updated in https://issues.apache.org/jira/browse/MXNET-1178
class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor){
def this(modelPathPrefix: String, inputDescriptors: java.util.List[DataDesc],
contexts: java.util.List[Context], epoch: Int)
def this(modelPathPrefix: String, inputDescriptors: java.lang.Iterable[DataDesc],
contexts: java.lang.Iterable[Context], epoch: Int)
= this {
val informationDesc = JavaConverters.asScalaIteratorConverter(inputDescriptors.iterator)
.asScala.toIndexedSeq map {a => a: org.apache.mxnet.DataDesc}
Expand Down Expand Up @@ -97,10 +97,10 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
}

/**
* Takes input as List of one dimensional arrays and creates the NDArray needed for inference
* Takes input as List of one dimensional iterables and creates the NDArray needed for inference
* The array will be reshaped based on the input descriptors.
*
* @param input: A List of a one-dimensional array.
* @param input: A List of a one-dimensional iterables of DType Float.
An extra List is needed for when the model has more than one input.
* @return Indexed sequence array of outputs
*/
Expand All @@ -118,10 +118,10 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
* This method is useful when the input is a batch of data
* Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
*
* @param input List of NDArrays
* @param input Iterable of NDArrays
* @return Output of predictions as NDArrays
*/
def predictWithNDArray(input: java.util.List[NDArray]):
def predictWithNDArray(input: java.lang.Iterable[NDArray]):
java.util.List[NDArray] = {
val ret = predictor.predictWithNDArray(convert(JavaConverters
.asScalaIteratorConverter(input.iterator).asScala.toIndexedSeq))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@

import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class ObjectDetectorTest {

Expand Down Expand Up @@ -92,6 +94,17 @@ public void testObjectDetectorWithBatchImage() {
Assert.assertEquals(expectedResult, actualResult);
}

@Test
public void testObjectDetectorWithIterableOfBatchImage() {

Set<BufferedImage> batchImage = new HashSet<>();
batchImage.add(inputImage);
Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, topK)).thenReturn(expectedResult);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.imageBatchObjectDetect(batchImage, topK);
Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, topK);
Assert.assertEquals(expectedResult, actualResult);
}

@Test
public void testObjectDetectorWithNDArrayInput() {

Expand All @@ -103,4 +116,16 @@ public void testObjectDetectorWithNDArrayInput() {
Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
Assert.assertEquals(expectedResult, actualResult);
}

@Test
public void testObjectDetectorWithIterableOfNDArrayInput() {

NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape());
Set<NDArray> inputL = new HashSet<>();
inputL.add(inputArr);
Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(expectedResult);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.objectDetectWithNDArray(inputL, topK);
Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
Assert.assertEquals(expectedResult, actualResult);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
import org.junit.Test;
import org.mockito.Mockito;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;

public class PredictorTest {

Expand Down Expand Up @@ -80,6 +78,31 @@ public void testPredictWithNDArray() {
Assert.assertEquals(expectedResult, actualOutput);
}

@Test
public void testPredictWithIterablesNDArray() {

float[] tmpArr = new float[224];
for (int y = 0; y < 224; y++)
tmpArr[y] = (int) (Math.random() * 10);

NDArray arr = new org.apache.mxnet.javaapi.NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));

Set<NDArray> inputSet = new HashSet<>();
inputSet.add(arr);

NDArray expected = new NDArray(tmpArr, new Shape(new int[] {1, 1, 1, 224}), new Context("cpu", 0));
List<NDArray> expectedResult = new ArrayList<>();
expectedResult.add(expected);

Mockito.when(mockPredictor.predictWithNDArray(inputSet)).thenReturn(expectedResult);

List<NDArray> actualOutput = mockPredictor.predictWithNDArray(inputSet);

Mockito.verify(mockPredictor, Mockito.times(1)).predictWithNDArray(inputSet);

Assert.assertEquals(expectedResult, actualOutput);
}

@Test
public void testPredictWithListOfFloatsAsInput() {
List<List<Float>> input = new ArrayList<>();
Expand Down