Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Addressed PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
piyushghai committed Jan 8, 2019
1 parent 4d95fed commit 41f3371
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,43 @@

package org.apache.mxnet.infer.javaapi;

import org.junit.Assert;
import org.junit.Test;

public class ObjectDetectorOutputTest {

private String predictedClassName = "lion";

private float delta = 0.00001f;

@Test
public void testConstructor() {

float[] arr = new float[]{0f, 1f, 2f, 3f, 4f};

ObjectDetectorOutput odOutput = new ObjectDetectorOutput("simba", arr);
ObjectDetectorOutput odOutput = new ObjectDetectorOutput(predictedClassName, arr);

Assert.assertEquals(odOutput.getClassName(), predictedClassName);
Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta);

assert (odOutput.getClassName().equals("simba"));
assert (odOutput.getProbability() == 0);
assert (odOutput.getXMin() == 1);
assert (odOutput.getXMax() == 2);
assert (odOutput.getYMin() == 3);
assert (odOutput.getYMax() == 4);
}

@Test (expected = ArrayIndexOutOfBoundsException.class)
public void testIncompleteArgsConstructor() {

float[] arr = new float[]{0f, 1f};

ObjectDetectorOutput odOutput = new ObjectDetectorOutput("simba", arr);
ObjectDetectorOutput odOutput = new ObjectDetectorOutput(predictedClassName, arr);

Assert.assertEquals(odOutput.getClassName(), predictedClassName);
Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);

odOutput.getYMax();
// This is where exception will be thrown
odOutput.getXMax();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,51 @@

public class ObjectDetectorTest {

List<DataDesc> inputDesc;
BufferedImage inputImage;
private List<DataDesc> inputDesc;
private BufferedImage inputImage;

List<List<ObjectDetectorOutput>> result;
private List<List<ObjectDetectorOutput>> expectedResult;

ObjectDetector objectDetector;
private ObjectDetector objectDetector;

private int batchSize = 1;

private int channels = 3;

private int imageHeight = 512;

private int imageWidth = 512;

private String dataName = "data";

private int topK = 5;

private String predictedClassName = "lion"; // Random string

private Shape getTestShape() {

return new Shape(new int[] {batchSize, channels, imageHeight, imageWidth});
}

@Before
public void setUp() {

inputDesc = new ArrayList<>();
inputDesc.add(new DataDesc("", new Shape(new int[]{1, 3, 512, 512}), DType.Float32(), Layout.NCHW()));
inputImage = new BufferedImage(512, 512, BufferedImage.TYPE_INT_RGB);
inputDesc.add(new DataDesc(dataName, getTestShape(), DType.Float32(), Layout.NCHW()));
inputImage = new BufferedImage(imageWidth, imageHeight, BufferedImage.TYPE_INT_RGB);
objectDetector = Mockito.mock(ObjectDetector.class);
result = new ArrayList<>();
result.add(new ArrayList<ObjectDetectorOutput>());
result.get(0).add(new ObjectDetectorOutput("simbaa", new float[]{}));
expectedResult = new ArrayList<>();
expectedResult.add(new ArrayList<ObjectDetectorOutput>());
expectedResult.get(0).add(new ObjectDetectorOutput(predictedClassName, new float[]{}));
}

@Test
public void testObjectDetectorWithInputImage() {

Mockito.when(objectDetector.imageObjectDetect(inputImage, 5)).thenReturn(result);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.imageObjectDetect(inputImage, 5);
Mockito.verify(objectDetector, Mockito.times(1)).imageObjectDetect(inputImage, 5);
Assert.assertEquals(result, actualResult);
Mockito.when(objectDetector.imageObjectDetect(inputImage, topK)).thenReturn(expectedResult);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.imageObjectDetect(inputImage, topK);
Mockito.verify(objectDetector, Mockito.times(1)).imageObjectDetect(inputImage, topK);
Assert.assertEquals(expectedResult, actualResult);
}


Expand All @@ -67,21 +86,21 @@ public void testObjectDetectorWithBatchImage() {

List<BufferedImage> batchImage = new ArrayList<>();
batchImage.add(inputImage);
Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, 5)).thenReturn(result);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.imageBatchObjectDetect(batchImage, 5);
Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, 5);
Assert.assertEquals(result, actualResult);
Mockito.when(objectDetector.imageBatchObjectDetect(batchImage, topK)).thenReturn(expectedResult);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.imageBatchObjectDetect(batchImage, topK);
Mockito.verify(objectDetector, Mockito.times(1)).imageBatchObjectDetect(batchImage, topK);
Assert.assertEquals(expectedResult, actualResult);
}

@Test
public void testObjectDetectorWithNDArrayInput() {

NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, new Shape(new int[] {1, 3, 512, 512}));
NDArray inputArr = ObjectDetector.bufferedImageToPixels(inputImage, getTestShape());
List<NDArray> inputL = new ArrayList<>();
inputL.add(inputArr);
Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(result);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.objectDetectWithNDArray(inputL, 5);
Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, 5);
Assert.assertEquals(result, actualResult);
Mockito.when(objectDetector.objectDetectWithNDArray(inputL, 5)).thenReturn(expectedResult);
List<List<ObjectDetectorOutput>> actualResult = objectDetector.objectDetectWithNDArray(inputL, topK);
Mockito.verify(objectDetector, Mockito.times(1)).objectDetectWithNDArray(inputL, topK);
Assert.assertEquals(expectedResult, actualResult);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void setUp() {
}

@Test
public void testPredictWithFloatArry() {
public void testPredictWithFloatArray() {

float tmp[][] = new float[1][224];
for (int x = 0; x < 1; x++) {
Expand All @@ -55,23 +55,6 @@ public void testPredictWithFloatArry() {
Assert.assertArrayEquals(expectedResult, actualResult);
}

@Test
public void testPredictWithDoubleArry() {

double tmp[][] = new double[1][224];
for (int x = 0; x < 1; x++) {
for (int y = 0; y < 224; y++)
tmp[x][y] = (int) (Math.random() * 10);
}

double [][] expectedResult = new double[][] {{1d, 2d}};
Mockito.when(mockPredictor.predict(tmp)).thenReturn(expectedResult);
double[][] actualResult = mockPredictor.predict(tmp);

Mockito.verify(mockPredictor, Mockito.times(1)).predict(tmp);
Assert.assertArrayEquals(expectedResult, actualResult);
}

@Test
public void testPredictWithNDArray() {

Expand Down

0 comments on commit 41f3371

Please sign in to comment.