diff --git a/ios/Podfile.lock b/ios/Podfile.lock index 7d8bede..61a265a 100644 --- a/ios/Podfile.lock +++ b/ios/Podfile.lock @@ -334,6 +334,9 @@ PODS: - React-jsinspector (0.71.8) - React-logger (0.71.8): - glog + - react-native-blob-jsi-helper (0.3.0): + - React + - React-Core - react-native-gcanvas (6.0.3): - GCanvas - React @@ -482,6 +485,7 @@ DEPENDENCIES: - React-jsiexecutor (from `../node_modules/react-native/ReactCommon/jsiexecutor`) - React-jsinspector (from `../node_modules/react-native/ReactCommon/jsinspector`) - React-logger (from `../node_modules/react-native/ReactCommon/logger`) + - react-native-blob-jsi-helper (from `../node_modules/react-native-blob-jsi-helper`) - "react-native-gcanvas (from `../node_modules/@flyskywhy/react-native-gcanvas`)" - react-native-image-picker (from `../node_modules/react-native-image-picker`) - react-native-quick-base64 (from `../node_modules/react-native-quick-base64`) @@ -515,7 +519,6 @@ SPEC REPOS: - Flipper-RSocket - FlipperKit - fmt - - GCanvas - libevent - onnxruntime-c - OpenSSL-Universal @@ -567,6 +570,8 @@ EXTERNAL SOURCES: :path: "../node_modules/react-native/ReactCommon/jsinspector" React-logger: :path: "../node_modules/react-native/ReactCommon/logger" + react-native-blob-jsi-helper: + :path: "../node_modules/react-native-blob-jsi-helper" react-native-gcanvas: :path: "../node_modules/@flyskywhy/react-native-gcanvas" react-native-image-picker: @@ -643,6 +648,7 @@ SPEC CHECKSUMS: React-jsiexecutor: 747911ab5921641b4ed7e4900065896597142125 React-jsinspector: c712f9e3bb9ba4122d6b82b4f906448b8a281580 React-logger: 342f358b8decfbf8f272367f4eacf4b6154061be + react-native-blob-jsi-helper: 0f650a3c8af9b44d379d38b50733b75e9eafee23 react-native-gcanvas: f333990fb1593272cd66c70275099cdac9e33821 react-native-image-picker: ec9b713e248760bfa0f879f0715391de4651a7cb react-native-quick-base64: 62290829c619fbabca4c41cfec75ae759d08fc1c @@ -665,6 +671,6 @@ SPEC CHECKSUMS: Yoga: 065f0b74dba4832d6e328238de46eb72c5de9556 YogaKit: f782866e155069a2cca2517aafea43200b01fd5a -PODFILE CHECKSUM: 6ffd45449d1e1316675abfb877476971a9009f56 +PODFILE CHECKSUM: b10522d68aadff0af5d00cb3edc5a2b89244d56c COCOAPODS: 1.11.3 diff --git a/package.json b/package.json index dfee9a6..00be649 100644 --- a/package.json +++ b/package.json @@ -25,6 +25,7 @@ "path-browserify": "^1.0.1", "react": "18.2.0", "react-native": "0.71.8", + "react-native-blob-jsi-helper": "^0.3.0", "react-native-fs": "^2.20.0", "react-native-image-picker": "^5.3.1", "react-native-quick-base64": "^2.0.6", diff --git a/patches/onnxruntime-react-native+1.14.0.patch b/patches/onnxruntime-react-native+1.14.0.patch index 71cc993..4a35452 100644 --- a/patches/onnxruntime-react-native+1.14.0.patch +++ b/patches/onnxruntime-react-native+1.14.0.patch @@ -12,11 +12,75 @@ index 4c8a318..65b58c1 100644 + implementation project(":onnxruntime-patched") + } +diff --git a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +index fe59cef..41c1dd2 100644 +--- a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java ++++ b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java +@@ -39,6 +39,8 @@ import java.util.Set; + import java.util.stream.Collectors; + import java.util.stream.Stream; + ++import com.facebook.react.modules.blob.BlobModule; ++ + @RequiresApi(api = Build.VERSION_CODES.N) + public class OnnxruntimeModule extends ReactContextBaseJavaModule { + private static ReactApplicationContext reactContext; +@@ -165,6 +167,8 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule { + throw new Exception("Model is not loaded: " + key); + } + ++ BlobModule blobModule = reactContext.getNativeModule(BlobModule.class); ++ + RunOptions runOptions = parseRunOptions(options); + + long startTime = System.currentTimeMillis(); +@@ -217,7 +221,7 @@ public class OnnxruntimeModule extends ReactContextBaseJavaModule { + Log.d("Duration", "inference: " + duration); + + startTime = System.currentTimeMillis(); +- WritableMap resultMap = TensorHelper.createOutputTensor(result); ++ WritableMap resultMap = TensorHelper.createOutputTensor(blobModule, result); + duration = System.currentTimeMillis() - startTime; + Log.d("Duration", "createOutputTensor: " + duration); + diff --git a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java -index 500141a..49b3abd 100644 +index 500141a..20c680f 100644 --- a/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java +++ b/node_modules/onnxruntime-react-native/android/src/main/java/ai/onnxruntime/reactnative/TensorHelper.java -@@ -164,7 +164,11 @@ public class TensorHelper { +@@ -29,6 +29,8 @@ import java.util.Objects; + import java.util.stream.Collectors; + import java.util.stream.Stream; + ++import com.facebook.react.modules.blob.BlobModule; ++ + public class TensorHelper { + /** + * Supported tensor data type +@@ -80,7 +82,7 @@ public class TensorHelper { + * It creates an output map from an output tensor. + * a data array is encoded as base64 string. + */ +- public static WritableMap createOutputTensor(OrtSession.Result result) throws Exception { ++ public static WritableMap createOutputTensor(BlobModule blobModule, OrtSession.Result result) throws Exception { + WritableMap outputTensorMap = Arguments.createMap(); + + Iterator> iterator = result.iterator(); +@@ -115,8 +117,12 @@ public class TensorHelper { + } + outputTensor.putArray("data", dataArray); + } else { +- String data = createOutputTensor(onnxTensor); +- outputTensor.putString("data", data); ++ // Blob ++ byte[] bufferArray = createOutputTensor(onnxTensor); ++ String blobId = blobModule.store(bufferArray); ++ int size = bufferArray.length; ++ outputTensor.putString("data", blobId); ++ outputTensor.putInt("size", size); + } + + outputTensorMap.putMap(outputName, outputTensor); +@@ -164,7 +170,11 @@ public class TensorHelper { tensor = OnnxTensor.createTensor(ortEnvironment, buffer, dims, OnnxJavaType.UINT8); break; } @@ -29,3 +93,102 @@ index 500141a..49b3abd 100644 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: +@@ -177,7 +187,7 @@ public class TensorHelper { + return tensor; + } + +- private static String createOutputTensor(OnnxTensor onnxTensor) throws Exception { ++ private static byte[] createOutputTensor(OnnxTensor onnxTensor) throws Exception { + TensorInfo tensorInfo = onnxTensor.getInfo(); + ByteBuffer buffer = null; + +@@ -224,8 +234,7 @@ public class TensorHelper { + throw new IllegalStateException("Unexpected type: " + tensorInfo.onnxType.toString()); + } + +- String data = Base64.encodeToString(buffer.array(), Base64.DEFAULT); +- return data; ++ return buffer.array(); + } + + private static final Map JsTensorTypeToOnnxTensorTypeMap = +diff --git a/node_modules/onnxruntime-react-native/ios/TensorHelper.mm b/node_modules/onnxruntime-react-native/ios/TensorHelper.mm +index 00c1c79..ed6c81c 100644 +--- a/node_modules/onnxruntime-react-native/ios/TensorHelper.mm ++++ b/node_modules/onnxruntime-react-native/ios/TensorHelper.mm +@@ -2,6 +2,8 @@ + // Licensed under the MIT License. + + #import "TensorHelper.h" ++#import ++#import + #import + + @implementation TensorHelper +@@ -109,8 +111,11 @@ + (NSDictionary *)createOutputTensor:(const std::vector &)outputNa + } + outputTensor[@"data"] = buffer; + } else { +- NSString *data = [self createOutputTensor:value]; +- outputTensor[@"data"] = data; ++ NSData *buffer = [self createOutputTensor:value]; ++ RCTBlobManager* blobManager = [[RCTBridge currentBridge] moduleForClass:RCTBlobManager.class]; ++ NSString* blobId = [blobManager store:buffer]; ++ outputTensor[@"data"] = blobId; ++ outputTensor[@"size"] = [NSNumber numberWithUnsignedInteger:buffer.length]; + } + + outputTensorMap[[NSString stringWithUTF8String:outputName]] = outputTensor; +@@ -170,15 +175,15 @@ + (NSDictionary *)createOutputTensor:(const std::vector &)outputNa + } + } + +-template static NSString *createOutputTensorT(const Ort::Value &tensor) { ++template static NSData *createOutputTensorT(const Ort::Value &tensor) { + const auto data = tensor.GetTensorData(); + NSData *buffer = [NSData dataWithBytesNoCopy:(void *)data + length:tensor.GetTensorTypeAndShapeInfo().GetElementCount() * sizeof(T) + freeWhenDone:false]; +- return [buffer base64EncodedStringWithOptions:0]; ++ return buffer; + } + +-+ (NSString *)createOutputTensor:(const Ort::Value &)tensor { +++ (NSData *)createOutputTensor:(const Ort::Value &)tensor { + ONNXTensorElementDataType tensorType = tensor.GetTensorTypeAndShapeInfo().GetElementType(); + + switch (tensorType) { +diff --git a/node_modules/onnxruntime-react-native/lib/backend.ts b/node_modules/onnxruntime-react-native/lib/backend.ts +index 4ebc364..7aee5a0 100644 +--- a/node_modules/onnxruntime-react-native/lib/backend.ts ++++ b/node_modules/onnxruntime-react-native/lib/backend.ts +@@ -4,6 +4,7 @@ + import {Buffer} from 'buffer'; + import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common'; + import {Platform} from 'react-native'; ++import {getArrayBufferForBlob} from 'react-native-blob-jsi-helper'; + + import {binding, Binding} from './binding'; + +@@ -98,7 +99,20 @@ class OnnxruntimeSessionHandler implements SessionHandler { + } + } + const input = this.encodeFeedsType(feeds); +- const results: Binding.ReturnType = await this.#inferenceSession.run(this.#key, input, outputNames, options); ++ let results: Binding.ReturnType = await this.#inferenceSession.run(this.#key, input, outputNames, options); ++ results = Object.entries(results).reduce((acc, [name, result]) => { ++ acc[name] = { ++ ...result, ++ data: getArrayBufferForBlob({ ++ _data: { ++ blobId: result.data, ++ offset: 0, ++ size: result.size, ++ } ++ }), ++ }; ++ return acc; ++ }, {}) + const output = this.decodeReturnType(results); + return output; + } diff --git a/yarn.lock b/yarn.lock index 2c7627d..adc81e3 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6595,6 +6595,11 @@ react-is@^17.0.1: resolved "https://registry.yarnpkg.com/react-is/-/react-is-17.0.2.tgz#e691d4a8e9c789365655539ab372762b0efb54f0" integrity sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w== +react-native-blob-jsi-helper@^0.3.0: + version "0.3.0" + resolved "https://registry.yarnpkg.com/react-native-blob-jsi-helper/-/react-native-blob-jsi-helper-0.3.0.tgz#a57a8467d9b08d620db1d9e546dbbef45e2996d2" + integrity sha512-9ez/zdiHEcuI86ufxSAWqiPEMjhtCW89DHlG3nVPhQ1vBi7cb7/jsrMYILVaNzGsxsW7vPPcMAs9Cd8hxo7M0w== + react-native-codegen@^0.71.5: version "0.71.5" resolved "https://registry.yarnpkg.com/react-native-codegen/-/react-native-codegen-0.71.5.tgz#454a42a891cd4ca5fc436440d301044dc1349c14"