Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions android/src/main/java/sq/flutter/tflite/TflitePlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
Expand All @@ -43,6 +44,7 @@
import java.util.PriorityQueue;
import java.util.Vector;


public class TflitePlugin implements MethodCallHandler {
private final Registrar mRegistrar;
private Interpreter tfLite;
Expand Down Expand Up @@ -144,6 +146,16 @@ public void onMethodCall(MethodCall call, Result result) {
catch (Exception e) {
result.error("Failed to run model" , e.getMessage(), e);
}
} else if (call.method.equals("runSegmentationOnImage")) {
try {
byte[] res = runSegmentationOnImage((HashMap) call.arguments);
result.success(res);
}
catch (Exception e) {
result.error("Failed to run model" , e.getMessage(), e);
}
} else {
result.error("Invalid method", call.method.toString(), "");
}
}

Expand Down Expand Up @@ -706,6 +718,85 @@ public int compare(Map<String, Object> lhs, Map<String, Object> rhs) {
return results;
}

private byte[] runSegmentationOnImage(HashMap args) throws IOException {
String path = args.get("path").toString();
double mean = (double)(args.get("imageMean"));
float IMAGE_MEAN = (float)mean;
double std = (double)(args.get("imageStd"));
float IMAGE_STD = (float)std;
List<Long> labelColors = (ArrayList)args.get("labelColors");

long startTime = SystemClock.uptimeMillis();
ByteBuffer input = feedInputTensorImage(path, IMAGE_MEAN, IMAGE_STD);
ByteBuffer output = ByteBuffer.allocateDirect(tfLite.getOutputTensor(0).numBytes());
output.order(ByteOrder.nativeOrder());
tfLite.run(input, output);
Log.v("time", "Inference took " + (SystemClock.uptimeMillis() - startTime));

if (input.limit() == 0) throw new RuntimeException("Unexpected input position, bad file?");
if (output.position() != output.limit()) throw new RuntimeException("Unexpected output position");

output.flip();
Bitmap outputArgmax = fetchArgmax(output, labelColors);
return compressPNG(outputArgmax);
}


Bitmap fetchArgmax(ByteBuffer output, List<Long> labelColors) {
Tensor outputTensor = tfLite.getOutputTensor(0);
int outputBatchSize = outputTensor.shape()[0];
assert outputBatchSize == 1;
int outputHeight = outputTensor.shape()[1];
int outputWidth = outputTensor.shape()[2];
int outputChannels = outputTensor.shape()[3];

Bitmap outputArgmax = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888);

if (outputTensor.dataType() == DataType.FLOAT32) {
for (int i = 0; i < outputHeight; ++i) {
for (int j = 0; j < outputWidth; ++j) {
int maxIndex = 0;
float maxValue = 0.0f;
for (int c = 0; c < outputChannels; ++c) {
float outputValue = output.getFloat();
if (outputValue > maxValue) {
maxIndex = c;
maxValue = outputValue;
}
}
int labelColor = labelColors.get(maxIndex).intValue();
outputArgmax.setPixel(j, i, labelColor);
}
}
} else {
for (int i = 0; i < outputHeight; ++i) {
for (int j = 0; j < outputWidth; ++j) {
int maxIndex = 0;
int maxValue = 0;
for (int c = 0; c < outputChannels; ++c) {
int outputValue = output.get();
if (outputValue > maxValue) {
maxIndex = c;
maxValue = outputValue;
}
}
int labelColor = labelColors.get(maxIndex).intValue();
outputArgmax.setPixel(j, i, labelColor);
}
}
}
return outputArgmax;
}

byte[] compressPNG(Bitmap bitmap) {
// https://stackoverflow.com/questions/4989182/converting-java-bitmap-to-byte-array#4989543
ByteArrayOutputStream stream = new ByteArrayOutputStream();
bitmap.compress(Bitmap.CompressFormat.PNG, 100, stream);
byte[] byteArray = stream.toByteArray();
// bitmap.recycle();
return byteArray;
}

private float expit(final float x) {
return (float) (1. / (1. + Math.exp(-x)));
}
Expand Down
Binary file added example/assets/deeplabv3_257_mv_gpu.tflite
Binary file not shown.
21 changes: 21 additions & 0 deletions example/assets/deeplabv3_257_mv_gpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
background
aeroplane
biyclce
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
potted plant
sheep
sofa
train
tv-monitor
166 changes: 115 additions & 51 deletions example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ void main() => runApp(new App());
const String mobile = "MobileNet";
const String ssd = "SSD MobileNet";
const String yolo = "Tiny YOLOv2";
const String deeplab = "DeepLab";

class App extends StatelessWidget {
@override
Expand All @@ -31,12 +32,18 @@ class MyApp extends StatefulWidget {
class _MyAppState extends State<MyApp> {
File _image;
List _recognitions;
String _model = "";
String _model = mobile;
double _imageHeight;
double _imageWidth;

Future getImage() async {
Future predictImagePicker() async {
var image = await ImagePicker.pickImage(source: ImageSource.gallery);
if (image == null) return;
predictImage(image);
}

Future predictImage(File image) async {
if (image == null) return;

switch (_model) {
case yolo:
Expand All @@ -45,6 +52,9 @@ class _MyAppState extends State<MyApp> {
case ssd:
ssdMobileNet(image);
break;
case deeplab:
segmentMobileNet(image);
break;
default:
recognizeImage(image);
// recognizeImageBinary(image);
Expand All @@ -67,6 +77,7 @@ class _MyAppState extends State<MyApp> {
@override
void initState() {
super.initState();
loadModel();
}

Future loadModel() async {
Expand All @@ -84,6 +95,11 @@ class _MyAppState extends State<MyApp> {
model: "assets/ssd_mobilenet.tflite",
labels: "assets/ssd_mobilenet.txt");
break;
case deeplab:
res = await Tflite.loadModel(
model: "assets/deeplabv3_257_mv_gpu.tflite",
labels: "assets/deeplabv3_257_mv_gpu.txt");
break;
default:
res = await Tflite.loadModel(
model: "assets/mobilenet_v1_1.0_224.tflite",
Expand Down Expand Up @@ -194,11 +210,25 @@ class _MyAppState extends State<MyApp> {
});
}

onSelect(model) {
Future segmentMobileNet(File image) async {
var recognitions = await Tflite.runSegmentationOnImage(
path: image.path,
imageMean: 127.5,
imageStd: 127.5,
);

setState(() {
_recognitions = recognitions;
});
}

onSelect(model) async {
setState(() {
_model = model;
_recognitions = null;
});
loadModel();
await loadModel();
predictImage(_image);
}

List<Widget> renderBoxes(Size screen) {
Expand All @@ -214,6 +244,7 @@ class _MyAppState extends State<MyApp> {
height: re["rect"]["h"] * factorY,
child: Container(
decoration: BoxDecoration(
borderRadius: BorderRadius.all(Radius.circular(8.0)),
border: Border.all(
color: blue,
width: 2,
Expand All @@ -235,58 +266,91 @@ class _MyAppState extends State<MyApp> {
@override
Widget build(BuildContext context) {
Size size = MediaQuery.of(context).size;
List<Widget> stackChildren = [];

if (_model == deeplab && _recognitions != null) {
stackChildren.add(Positioned(
top: 0.0,
left: 0.0,
width: size.width,
child: _image == null
? Text('No image selected.')
: Container(
decoration: BoxDecoration(
image: DecorationImage(
alignment: Alignment.topCenter,
image: MemoryImage(_recognitions),
fit: BoxFit.fill)),
child: Opacity(opacity: 0.3, child: Image.file(_image))),
));
} else {
stackChildren.add(Positioned(
top: 0.0,
left: 0.0,
width: size.width,
child: _image == null ? Text('No image selected.') : Image.file(_image),
));
}

if (_model == mobile) {
stackChildren.add(Center(
child: Column(
children: _recognitions != null
? _recognitions.map((res) {
return Text(
"${res["index"]} - ${res["label"]}: ${res["confidence"].toStringAsFixed(3)}",
style: TextStyle(
color: Colors.black,
fontSize: 20.0,
background: Paint()..color = Colors.white,
),
);
}).toList()
: [],
),
));
} else if (_model == ssd || _model == yolo) {
stackChildren.addAll(renderBoxes(size));
}

return Scaffold(
appBar: AppBar(
title: const Text('tflite example app'),
),
body: _model == ""
? Center(
child: Column(
children: <Widget>[
RaisedButton(
child: const Text(mobile),
onPressed: () => onSelect(mobile),
),
RaisedButton(
child: const Text(ssd),
onPressed: () => onSelect(ssd),
),
RaisedButton(
child: const Text(yolo),
onPressed: () => onSelect(yolo),
),
],
),
)
: Stack(
children: <Widget>[
Container(
child: _image == null
? Text('No image selected.')
: Image.file(_image),
actions: <Widget>[
PopupMenuButton<String>(
onSelected: onSelect,
itemBuilder: (context) {
List<PopupMenuEntry<String>> menuEntries = [
const PopupMenuItem<String>(
child: Text(mobile),
value: mobile,
),
_model == mobile
? Center(
child: Column(
children: _recognitions != null
? _recognitions.map((res) {
return Text(
"${res["index"]} - ${res["label"]}: ${res["confidence"].toString()}",
style: TextStyle(
color: Colors.black,
fontSize: 20.0,
background: Paint()..color = Colors.white,
),
);
}).toList()
: [],
),
)
: Stack(children: renderBoxes(size)),
],
),
const PopupMenuItem<String>(
child: Text(ssd),
value: ssd,
),
const PopupMenuItem<String>(
child: Text(yolo),
value: yolo,
),
];

if (Platform.isAndroid) {
menuEntries.add(const PopupMenuItem<String>(
child: Text(deeplab),
value: deeplab,
));
}
return menuEntries;
},
)
],
),
body: Stack(
children: stackChildren,
),
floatingActionButton: FloatingActionButton(
onPressed: getImage,
onPressed: predictImagePicker,
tooltip: 'Pick Image',
child: Icon(Icons.image),
),
Expand Down
3 changes: 3 additions & 0 deletions example/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ flutter:
- assets/yolov2_tiny.txt
- assets/ssd_mobilenet.tflite
- assets/ssd_mobilenet.txt
- assets/deeplabv3_257_mv_gpu.tflite
- assets/deeplabv3_257_mv_gpu.txt


# An image asset can refer to one or more resolution-specific "variants", see
# https://flutter.io/assets-and-images/#resolution-aware.
Expand Down
Loading