Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experimental] Use blob for decode result instead of base64 #1

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,9 @@ PODS:
- React-jsinspector (0.71.8)
- React-logger (0.71.8):
- glog
- react-native-blob-jsi-helper (0.3.0):
- React
- React-Core
- react-native-gcanvas (6.0.3):
- GCanvas
- React
Expand Down Expand Up @@ -482,6 +485,7 @@ DEPENDENCIES:
- React-jsiexecutor (from `../node_modules/react-native/ReactCommon/jsiexecutor`)
- React-jsinspector (from `../node_modules/react-native/ReactCommon/jsinspector`)
- React-logger (from `../node_modules/react-native/ReactCommon/logger`)
- react-native-blob-jsi-helper (from `../node_modules/react-native-blob-jsi-helper`)
- "react-native-gcanvas (from `../node_modules/@flyskywhy/react-native-gcanvas`)"
- react-native-image-picker (from `../node_modules/react-native-image-picker`)
- react-native-quick-base64 (from `../node_modules/react-native-quick-base64`)
Expand Down Expand Up @@ -515,7 +519,6 @@ SPEC REPOS:
- Flipper-RSocket
- FlipperKit
- fmt
- GCanvas
- libevent
- onnxruntime-c
- OpenSSL-Universal
Expand Down Expand Up @@ -567,6 +570,8 @@ EXTERNAL SOURCES:
:path: "../node_modules/react-native/ReactCommon/jsinspector"
React-logger:
:path: "../node_modules/react-native/ReactCommon/logger"
react-native-blob-jsi-helper:
:path: "../node_modules/react-native-blob-jsi-helper"
react-native-gcanvas:
:path: "../node_modules/@flyskywhy/react-native-gcanvas"
react-native-image-picker:
Expand Down Expand Up @@ -643,6 +648,7 @@ SPEC CHECKSUMS:
React-jsiexecutor: 747911ab5921641b4ed7e4900065896597142125
React-jsinspector: c712f9e3bb9ba4122d6b82b4f906448b8a281580
React-logger: 342f358b8decfbf8f272367f4eacf4b6154061be
react-native-blob-jsi-helper: 0f650a3c8af9b44d379d38b50733b75e9eafee23
react-native-gcanvas: f333990fb1593272cd66c70275099cdac9e33821
react-native-image-picker: ec9b713e248760bfa0f879f0715391de4651a7cb
react-native-quick-base64: 62290829c619fbabca4c41cfec75ae759d08fc1c
Expand All @@ -665,6 +671,6 @@ SPEC CHECKSUMS:
Yoga: 065f0b74dba4832d6e328238de46eb72c5de9556
YogaKit: f782866e155069a2cca2517aafea43200b01fd5a

PODFILE CHECKSUM: 6ffd45449d1e1316675abfb877476971a9009f56
PODFILE CHECKSUM: b10522d68aadff0af5d00cb3edc5a2b89244d56c

COCOAPODS: 1.11.3
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"path-browserify": "^1.0.1",
"react": "18.2.0",
"react-native": "0.71.8",
"react-native-blob-jsi-helper": "^0.3.0",
"react-native-fs": "^2.20.0",
"react-native-image-picker": "^5.3.1",
"react-native-quick-base64": "^2.0.6",
Expand Down
167 changes: 165 additions & 2 deletions patches/onnxruntime-react-native+1.14.0.patch
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,75 @@ index 4c8a318..65b58c1 100644
+ implementation project(":onnxruntime-patched")
+
}
diff --git a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java
index fe59cef..41c1dd2 100644
--- a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java
+++ b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java
@@ -39,6 +39,8 @@ import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

+import com.facebook.react.modules.blob.BlobModule;
+
@RequiresApi(api = Build.VERSION_CODES.N)
public class OnnxruntimeModule extends ReactContextBaseJavaModule {
private static ReactApplicationContext reactContext;
@@ -165,6 +167,8 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
throw new Exception("Model is not loaded: " + key);
}

+ BlobModule blobModule = reactContext.getNativeModule(BlobModule.class);
+
RunOptions runOptions = parseRunOptions(options);

long startTime = System.currentTimeMillis();
@@ -217,7 +221,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule {
Log.d("Duration", "inference: " + duration);

startTime = System.currentTimeMillis();
- WritableMap resultMap = TensorHelper.createOutputTensor(result);
+ WritableMap resultMap = TensorHelper.createOutputTensor(blobModule, result);
duration = System.currentTimeMillis() - startTime;
Log.d("Duration", "createOutputTensor: " + duration);

diff --git a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java
index 500141a..49b3abd 100644
index 500141a..20c680f 100644
--- a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java
+++ b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java
@@ -164,7 +164,11 @@ public class TensorHelper {
@@ -29,6 +29,8 @@ import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

+import com.facebook.react.modules.blob.BlobModule;
+
public class TensorHelper {
/**
* Supported tensor data type
@@ -80,7 +82,7 @@ public class TensorHelper {
* It creates an output map from an output tensor.
* a data array is encoded as base64 string.
*/
- public static WritableMap createOutputTensor(OrtSession.Result result) throws Exception {
+ public static WritableMap createOutputTensor(BlobModule blobModule, OrtSession.Result result) throws Exception {
WritableMap outputTensorMap = Arguments.createMap();

Iterator<Map.Entry<String, OnnxValue>> iterator = result.iterator();
@@ -115,8 +117,12 @@ public class TensorHelper {
}
outputTensor.putArray("data", dataArray);
} else {
- String data = createOutputTensor(onnxTensor);
- outputTensor.putString("data", data);
+ // Blob
+ byte[] bufferArray = createOutputTensor(onnxTensor);
+ String blobId = blobModule.store(bufferArray);
+ int size = bufferArray.length;
+ outputTensor.putString("data", blobId);
+ outputTensor.putInt("size", size);
}

outputTensorMap.putMap(outputName, outputTensor);
@@ -164,7 +170,11 @@ public class TensorHelper {
tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.UINT8);
break;
}
Expand All @@ -29,3 +93,102 @@ index 500141a..49b3abd 100644
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
@@ -177,7 +187,7 @@ public class TensorHelper {
return tensor;
}

- private static String createOutputTensor(OnnxTensor onnxTensor) throws Exception {
+ private static byte[] createOutputTensor(OnnxTensor onnxTensor) throws Exception {
TensorInfo tensorInfo = onnxTensor.getInfo();
ByteBuffer buffer = null;

@@ -224,8 +234,7 @@ public class TensorHelper {
throw new IllegalStateException("Unexpected type: " + tensorInfo.onnxType.toString());
}

- String data = Base64.encodeToString(buffer.array(), Base64.DEFAULT);
- return data;
+ return buffer.array();
}

private static final Map<String, TensorInfo.OnnxTensorType> JsTensorTypeToOnnxTensorTypeMap =
diff --git a/node_modules/onnxruntime-react-native/ios/TensorHelper.mm b/node_modules/onnxruntime-react-native/ios/TensorHelper.mm
index 00c1c79..ed6c81c 100644
--- a/node_modules/onnxruntime-react-native/ios/TensorHelper.mm
+++ b/node_modules/onnxruntime-react-native/ios/TensorHelper.mm
@@ -2,6 +2,8 @@
// Licensed under the MIT License.

#import "TensorHelper.h"
+#import <React/RCTBlobManager.h>
+#import <React/RCTBridge+Private.h>
#import <Foundation/Foundation.h>

@implementation TensorHelper
@@ -109,8 +111,11 @@ + (NSDictionary *)createOutputTensor:(const std::vector<const char *> &)outputNa
}
outputTensor[@"data"] = buffer;
} else {
- NSString *data = [self createOutputTensor:value];
- outputTensor[@"data"] = data;
+ NSData *buffer = [self createOutputTensor:value];
+ RCTBlobManager* blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class];
+ NSString* blobId = [blobManager store:buffer];
+ outputTensor[@"data"] = blobId;
+ outputTensor[@"size"] = [NSNumber numberWithUnsignedInteger:buffer.length];
}

outputTensorMap[[NSString stringWithUTF8String:outputName]] = outputTensor;
@@ -170,15 +175,15 @@ + (NSDictionary *)createOutputTensor:(const std::vector<const char *> &)outputNa
}
}

-template <typename T> static NSString *createOutputTensorT(const Ort::Value &tensor) {
+template <typename T> static NSData *createOutputTensorT(const Ort::Value &tensor) {
const auto data = tensor.GetTensorData<T>();
NSData *buffer = [NSData dataWithBytesNoCopy:(void *)data
length:tensor.GetTensorTypeAndShapeInfo().GetElementCount() * sizeof(T)
freeWhenDone:false];
- return [buffer base64EncodedStringWithOptions:0];
+ return buffer;
}

-+ (NSString *)createOutputTensor:(const Ort::Value &)tensor {
++ (NSData *)createOutputTensor:(const Ort::Value &)tensor {
ONNXTensorElementDataType tensorType = tensor.GetTensorTypeAndShapeInfo().GetElementType();

switch (tensorType) {
diff --git a/node_modules/onnxruntime-react-native/lib/backend.ts b/node_modules/onnxruntime-react-native/lib/backend.ts
index 4ebc364..7aee5a0 100644
--- a/node_modules/onnxruntime-react-native/lib/backend.ts
+++ b/node_modules/onnxruntime-react-native/lib/backend.ts
@@ -4,6 +4,7 @@
import {Buffer} from 'buffer';
import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common';
import {Platform} from 'react-native';
+import {getArrayBufferForBlob} from 'react-native-blob-jsi-helper';

import {binding, Binding} from './binding';

@@ -98,7 +99,20 @@ class OnnxruntimeSessionHandler implements SessionHandler {
}
}
const input = this.encodeFeedsType(feeds);
- const results: Binding.ReturnType = await this.#inferenceSession.run(this.#key, input, outputNames, options);
+ let results: Binding.ReturnType = await this.#inferenceSession.run(this.#key, input, outputNames, options);
+ results = Object.entries(results).reduce((acc, [name, result]) => {
+ acc[name] = {
+ ...result,
+ data: getArrayBufferForBlob({
+ _data: {
+ blobId: result.data,
+ offset: 0,
+ size: result.size,
+ }
+ }),
+ };
+ return acc;
+ }, {})
const output = this.decodeReturnType(results);
return output;
}
5 changes: 5 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -6595,6 +6595,11 @@ react-is@^17.0.1:
resolved "https://registry.yarnpkg.com/react-is/-/react-is-17.0.2.tgz#e691d4a8e9c789365655539ab372762b0efb54f0"
integrity sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==

react-native-blob-jsi-helper@^0.3.0:
version "0.3.0"
resolved "https://registry.yarnpkg.com/react-native-blob-jsi-helper/-/react-native-blob-jsi-helper-0.3.0.tgz#a57a8467d9b08d620db1d9e546dbbef45e2996d2"
integrity sha512-9ez/zdiHEcuI86ufxSAWqiPEMjhtCW89DHlG3nVPhQ1vBi7cb7/jsrMYILVaNzGsxsW7vPPcMAs9Cd8hxo7M0w==

react-native-codegen@^0.71.5:
version "0.71.5"
resolved "https://registry.yarnpkg.com/react-native-codegen/-/react-native-codegen-0.71.5.tgz#454a42a891cd4ca5fc436440d301044dc1349c14"
Expand Down