Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented CameraX and Data Binding in Object Detection App #341

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
37 changes: 30 additions & 7 deletions lite/examples/object_detection/android/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ android {
dimension "tfliteInference"
}
}

buildFeatures {
dataBinding = true
}
}

// import DownloadModels task
Expand All @@ -54,12 +58,31 @@ dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar','*.aar'])
interpreterImplementation project(":lib_interpreter")
taskApiImplementation project(":lib_task_api")
// AndroidX / Material versions bumped to releases compatible with CameraX 1.0.1.
implementation 'androidx.appcompat:appcompat:1.3.1'
implementation 'androidx.coordinatorlayout:coordinatorlayout:1.1.0'
implementation 'com.google.android.material:material:1.4.0'

// Instrumented-test dependencies. The merge left two conflicting copies of
// this set (1.1.1/1.0.1/1.2.0/1.1.0 and 1.1.3/1.1.3/1.4.0/1.4.0); only the
// newer set is kept so Gradle does not have to resolve duplicate versions.
androidTestImplementation 'androidx.test.ext:junit:1.1.3'
androidTestImplementation 'com.google.truth:truth:1.1.3'
androidTestImplementation 'androidx.test:runner:1.4.0'
androidTestImplementation 'androidx.test:rules:1.4.0'

// Required -- JUnit 4 framework
androidTestImplementation 'junit:junit:4.13.2'
// Optional -- Robolectric environment
androidTestImplementation 'androidx.test:core:1.4.0'
// Optional -- Mockito framework
androidTestImplementation 'org.mockito:mockito-core:3.11.2'
implementation 'net.bytebuddy:byte-buddy-android-test:1.11.12'
implementation 'org.tensorflow:tensorflow-lite-support:0.2.0'
implementation 'org.mockito:mockito-android:3.11.2'

// CameraX dependencies
def camerax_version = "1.0.1"
// CameraX core library using camera2 implementation
implementation "androidx.camera:camera-camera2:$camerax_version"
// CameraX Lifecycle Library
implementation "androidx.camera:camera-lifecycle:$camerax_version"
// CameraX View class
implementation "androidx.camera:camera-view:1.0.0-alpha27"
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
dining_table 27.492085 97.94615 623.1435 444.8627 0.48828125
knife 342.53433 243.71082 583.89185 416.34595 0.4765625
cup 68.025925 197.5857 202.02031 374.2206 0.4375
book 185.43098 139.64153 244.51149 203.37737 0.3125
knife 345.29675 242.38895 585.65424 415.0241 0.54
dining_table 24.836613 95.182755 620.488 447.6261 0.55
wine_glass 63.532368 202.38976 204.03336 387.60184 0.51
book 186.38379 138.98523 242.53781 205.76802 0.32
Original file line number Diff line number Diff line change
Expand Up @@ -17,87 +17,77 @@
package org.tensorflow.lite.examples.detection;

Copy link
Contributor

@lintian06 lintian06 Aug 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't delete the test, but make it work. We need to test the change and make sure it works as intended.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will add this asap!

import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.openMocks;
import static java.lang.Math.abs;
import static java.lang.Math.max;
import static java.lang.Math.min;

import android.content.Context;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.BitmapFactory;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.graphics.ImageFormat;
import android.graphics.RectF;
import android.util.Size;
import android.media.Image;

import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.platform.app.InstrumentationRegistry;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.tensorflow.lite.examples.detection.env.ImageUtils;
import org.tensorflow.lite.examples.detection.tflite.Detector;
import org.tensorflow.lite.examples.detection.tflite.Detector.Recognition;
import org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel;
import org.tensorflow.lite.support.image.ColorSpaceType;

/** Golden test for Object Detection Reference app. */
@RunWith(AndroidJUnit4.class)
public class DetectorTest {

private static final int MODEL_INPUT_SIZE = 300;
private static final boolean IS_MODEL_QUANTIZED = true;
private static final String MODEL_FILE = "detect.tflite";
private static final String LABELS_FILE = "labelmap.txt";
private static final Size IMAGE_SIZE = new Size(640, 480);

private Detector detector;
private Bitmap croppedBitmap;
private Matrix frameToCropTransform;
private Matrix cropToFrameTransform;
private final Context applicationContext = ApplicationProvider.getApplicationContext();

@Before
public void setUp() throws IOException {
  // Initialize any @Mock-annotated fields before the detector is built.
  openMocks(this);
  // NOTE(review): the merged diff carried both the old
  // InstrumentationRegistry.getInstrumentation().getContext() argument and the
  // new applicationContext; create() takes a single Context, so only the
  // application context is passed here.
  detector =
      TFLiteObjectDetectionAPIModel.create(
          applicationContext,
          MODEL_FILE,
          LABELS_FILE,
          MODEL_INPUT_SIZE,
          IS_MODEL_QUANTIZED);
  int cropSize = MODEL_INPUT_SIZE;
  int previewWidth = IMAGE_SIZE.getWidth();
  int previewHeight = IMAGE_SIZE.getHeight();
  int sensorOrientation = 0;
  croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);

  // Maps the camera preview frame onto the square model input, and the
  // inverse maps detections back into frame coordinates.
  frameToCropTransform =
      ImageUtils.getTransformationMatrix(
          previewWidth, previewHeight,
          cropSize, cropSize,
          sensorOrientation, false);
  cropToFrameTransform = new Matrix();
  frameToCropTransform.invert(cropToFrameTransform);
}

@Test
public void detectionResultsShouldNotChange() throws Exception {
Canvas canvas = new Canvas(croppedBitmap);
canvas.drawBitmap(loadImage("table.jpg"), frameToCropTransform, null);
final List<Recognition> results = detector.recognizeImage(croppedBitmap);
Bitmap assetsBitmap = loadImage("table.jpg");
final List<Recognition> results = detector.recognizeImage(mockMediaImageFromBitmap(assetsBitmap, ColorSpaceType.NV21), 0);
final List<Recognition> expected = loadRecognitions("table_results.txt");

for (Recognition target : expected) {
// Find a matching result in results
boolean matched = false;
for (Recognition item : results) {
RectF bbox = new RectF();
cropToFrameTransform.mapRect(bbox, item.getLocation());
if (item.getTitle().equals(target.getTitle())
&& matchBoundingBoxes(bbox, target.getLocation())
&& matchBoundingBoxes(item.getLocation(), target.getLocation())
&& matchConfidence(item.getConfidence(), target.getConfidence())) {
matched = true;
break;
Expand Down Expand Up @@ -135,7 +125,7 @@ private static Bitmap loadImage(String fileName) throws Exception {
// category bbox.left bbox.top bbox.right bbox.bottom confidence
// ...
// Example:
// Apple 99 25 30 75 80 0.99
// Apple 99 25 30 75 0.99
// Banana 25 90 75 200 0.98
// ...
private static List<Recognition> loadRecognitions(String fileName) throws Exception {
Expand All @@ -161,4 +151,137 @@ private static List<Recognition> loadRecognitions(String fileName) throws Except
}
return result;
}

/**
 * Computes the plane offsets and strides needed to lay out the YUV planes of a
 * {@code width}x{@code height} image in one contiguous byte array for the given
 * {@link ColorSpaceType}.
 *
 * @throws IllegalArgumentException for color spaces other than NV12/NV21/YV12/YV21
 */
private static YuvPlaneInfo createYuvPlaneInfo(
    ColorSpaceType colorSpaceType, int width, int height) {
  final int ySize = width * height;
  final int uvSize = ((width + 1) / 2) * ((height + 1) / 2);
  final int halfWidth = (width + 1) / 2;

  final int uOffset;
  final int vOffset;
  final int uvPixelStride;
  final int uvRowStride;
  switch (colorSpaceType) {
    case NV12:
      // Semi-planar: interleaved UV directly after Y, U byte first.
      uOffset = ySize;
      vOffset = ySize + 1;
      uvPixelStride = 2;
      uvRowStride = halfWidth * 2;
      break;
    case NV21:
      // Semi-planar: interleaved VU directly after Y, V byte first.
      vOffset = ySize;
      uOffset = ySize + 1;
      uvPixelStride = 2;
      uvRowStride = halfWidth * 2;
      break;
    case YV12:
      // Fully planar: Y plane, then V plane, then U plane.
      vOffset = ySize;
      uOffset = ySize + uvSize;
      uvPixelStride = 1;
      uvRowStride = halfWidth;
      break;
    case YV21:
      // Fully planar: Y plane, then U plane, then V plane.
      uOffset = ySize;
      vOffset = ySize + uvSize;
      uvPixelStride = 1;
      uvRowStride = halfWidth;
      break;
    default:
      throw new IllegalArgumentException(
          "ColorSpaceType: " + colorSpaceType.name() + ", is unsupported.");
  }

  return YuvPlaneInfo.create(
      uOffset,
      vOffset,
      /*yRowStride=*/ width,
      uvRowStride,
      uvPixelStride,
      ySize,
      uvSize);
}

/**
 * Converts an ARGB {@link Bitmap} into one byte array holding its YUV planes,
 * laid out according to {@code colorSpaceType}. Chroma is subsampled once per
 * 2x2 block, taken from the block's top-left pixel.
 */
private static byte[] getYuvBytesFromBitmap(Bitmap bitmap, ColorSpaceType colorSpaceType) {
  final int width = bitmap.getWidth();
  final int height = bitmap.getHeight();
  final int[] argb = new int[width * height];
  bitmap.getPixels(argb, 0, width, 0, 0, width, height);

  final YuvPlaneInfo planeInfo = createYuvPlaneInfo(colorSpaceType, width, height);
  final byte[] yuv = new byte[planeInfo.getYBufferSize() + planeInfo.getUvBufferSize() * 2];

  final int chromaStep = planeInfo.getUvPixelStride();
  int yOut = 0;
  int uOut = planeInfo.getUIndex();
  int vOut = planeInfo.getVIndex();

  int src = 0;
  for (int row = 0; row < height; ++row) {
    for (int col = 0; col < width; ++col, ++src) {
      final int r = (argb[src] >> 16) & 0xff;
      final int g = (argb[src] >> 8) & 0xff;
      final int b = argb[src] & 0xff;

      // RGB -> YUV using the same coefficients as the original conversion.
      final int y = (int) (0.299f * r + 0.587f * g + 0.114f * b);
      final int v = (int) ((r - y) * 0.713f + 128);
      final int u = (int) ((b - y) * 0.564f + 128);

      yuv[yOut++] = (byte) Math.max(0, Math.min(255, y));

      // Write chroma only for even rows/columns (2x2 subsampling).
      if ((row & 1) == 0 && (col & 1) == 0) {
        yuv[vOut] = (byte) Math.max(0, Math.min(255, v));
        yuv[uOut] = (byte) Math.max(0, Math.min(255, u));
        vOut += chromaStep;
        uOut += chromaStep;
      }
    }
  }
  return yuv;
}

/**
 * Builds a Mockito mock of {@link Image} in {@link ImageFormat#YUV_420_888}
 * whose three planes are backed by the YUV conversion of {@code bitmap}.
 *
 * @param bitmap source ARGB bitmap to convert
 * @param colorSpaceType YUV layout used to position the U/V planes
 * @return a mocked {@link Image} usable wherever a camera frame is expected
 */
public static Image mockMediaImageFromBitmap(Bitmap bitmap, ColorSpaceType colorSpaceType) {
  // Converts the RGB Bitmap to a contiguous YUV byte array.
  byte[] yuv = getYuvBytesFromBitmap(bitmap, colorSpaceType);

  int width = bitmap.getWidth();
  int height = bitmap.getHeight();
  YuvPlaneInfo yuvPlaneInfo = createYuvPlaneInfo(colorSpaceType, width, height);

  ByteBuffer yuvBuffer =
      ByteBuffer.allocateDirect(
          yuvPlaneInfo.getYBufferSize() + yuvPlaneInfo.getUvBufferSize() * 2);
  yuvBuffer.put(yuv);

  // Each plane buffer is a slice view starting at that plane's first byte.
  yuvBuffer.rewind();
  ByteBuffer yPlane = yuvBuffer.slice();

  yuvBuffer.rewind();
  yuvBuffer.position(yuvPlaneInfo.getUIndex());
  ByteBuffer uPlane = yuvBuffer.slice();

  yuvBuffer.rewind();
  yuvBuffer.position(yuvPlaneInfo.getVIndex());
  ByteBuffer vPlane = yuvBuffer.slice();

  Image.Plane mockPlaneY = mock(Image.Plane.class);
  when(mockPlaneY.getBuffer()).thenReturn(yPlane);
  when(mockPlaneY.getRowStride()).thenReturn(yuvPlaneInfo.getYRowStride());
  // Fix: without this stub Mockito's default returns 0, but a YUV_420_888
  // Y plane always has pixel stride 1; consumers stepping the Y plane by
  // getPixelStride() would otherwise misread the buffer or never advance.
  when(mockPlaneY.getPixelStride()).thenReturn(1);
  Image.Plane mockPlaneU = mock(Image.Plane.class);
  when(mockPlaneU.getBuffer()).thenReturn(uPlane);
  when(mockPlaneU.getRowStride()).thenReturn(yuvPlaneInfo.getUvRowStride());
  when(mockPlaneU.getPixelStride()).thenReturn(yuvPlaneInfo.getUvPixelStride());
  Image.Plane mockPlaneV = mock(Image.Plane.class);
  when(mockPlaneV.getBuffer()).thenReturn(vPlane);
  when(mockPlaneV.getRowStride()).thenReturn(yuvPlaneInfo.getUvRowStride());
  when(mockPlaneV.getPixelStride()).thenReturn(yuvPlaneInfo.getUvPixelStride());

  Image imageMock = mock(Image.class);
  when(imageMock.getFormat()).thenReturn(ImageFormat.YUV_420_888);
  when(imageMock.getPlanes()).thenReturn(new Image.Plane[]{mockPlaneY, mockPlaneU, mockPlaneV});
  when(imageMock.getWidth()).thenReturn(width);
  when(imageMock.getHeight()).thenReturn(height);
  return imageMock;
}
}
Loading