Skip to content

Commit

Permalink
[js/rn] Implement blob exchange by JSI instead of use base64 (#16094)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

- Create `OnnxruntimeJSIHelper` native module to provide two JSI
functions
- `jsiOnnxruntimeStoreArrayBuffer`: Store buffer in Blob Manager &
return blob object (iOS: RCTBlobManager, Android: BlobModule)
  - `jsiOnnxruntimeResolveArrayBuffer`: Use blob object to get buffer
- The part of implementation is reference to
[react-native-blob-jsi-helper](https://github.com/mrousavy/react-native-blob-jsi-helper)
- Replace base64 encode/decode
  - `loadModelFromBlob`: Rename from `loadModelFromBase64EncodedBuffer`
  - `run`: Use blob object to replace input.data & results[].data

For [this
context](#16031 (comment)),
it saved a lot of time and avoid JS thread blocking in decode return
type, it is 3700ms -> 5~20ms for the case. (resolve function only takes
0.x ms)

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

It’s related to #16031, but not a full implementation for migrate to
JSI.

It just uses JSI through BlobManager to replace the slow part (base64
encode / decode).

Rewriting it entirely in JSI could be complicated, like type convertion
and threading. This PR might be considered a minor change.

/cc @skottmckay
  • Loading branch information
jhen0409 authored Jun 16, 2023
1 parent 9110e5b commit ea1a5cf
Show file tree
Hide file tree
Showing 23 changed files with 935 additions and 141 deletions.
37 changes: 37 additions & 0 deletions js/react_native/android/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
project(OnnxruntimeJSIHelper)
cmake_minimum_required(VERSION 3.9.0)

set (PACKAGE_NAME "onnxruntime-react-native")
set (BUILD_DIR ${CMAKE_SOURCE_DIR}/build)
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_CXX_STANDARD 17)

file(TO_CMAKE_PATH "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi/jsi/jsi.cpp" libPath)

include_directories(
"${NODE_MODULES_DIR}/react-native/React"
"${NODE_MODULES_DIR}/react-native/React/Base"
"${NODE_MODULES_DIR}/react-native/ReactCommon/jsi"
)

add_library(onnxruntimejsihelper
SHARED
${libPath}
src/main/cpp/cpp-adapter.cpp
)

# Configure C++ 17
set_target_properties(
onnxruntimejsihelper PROPERTIES
CXX_STANDARD 17
CXX_EXTENSIONS OFF
POSITION_INDEPENDENT_CODE ON
)

find_library(log-lib log)

target_link_libraries(
onnxruntimejsihelper
${log-lib} # <-- Logcat logger
android # <-- Android JNI core
)
71 changes: 69 additions & 2 deletions js/react_native/android/build.gradle
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import java.nio.file.Paths

buildscript {
repositories {
google()
Expand All @@ -20,6 +22,32 @@ def getExtOrIntegerDefault(name) {
return rootProject.ext.has(name) ? rootProject.ext.get(name) : (project.properties['OnnxruntimeModule_' + name]).toInteger()
}

def reactNativeArchitectures() {
def value = project.getProperties().get("reactNativeArchitectures")
return value ? value.split(",") : ["armeabi-v7a", "x86", "x86_64", "arm64-v8a"]
}

def resolveBuildType() {
Gradle gradle = getGradle()
String tskReqStr = gradle.getStartParameter().getTaskRequests()['args'].toString()
return tskReqStr.contains('Release') ? 'release' : 'debug'
}

static def findNodeModules(baseDir) {
def basePath = baseDir.toPath().normalize()
while (basePath) {
def nodeModulesPath = Paths.get(basePath.toString(), "node_modules")
def reactNativePath = Paths.get(nodeModulesPath.toString(), "react-native")
if (nodeModulesPath.toFile().exists() && reactNativePath.toFile().exists()) {
return nodeModulesPath.toString()
}
basePath = basePath.getParent()
}
throw new GradleException("onnxruntime-react-native: Failed to find node_modules/ path!")
}

def nodeModules = findNodeModules(projectDir);

def checkIfOrtExtensionsEnabled() {
// locate user's project dir
def reactnativeRootDir = project.rootDir.parentFile
Expand All @@ -38,6 +66,9 @@ def checkIfOrtExtensionsEnabled() {

boolean ortExtensionsEnabled = checkIfOrtExtensionsEnabled()

def REACT_NATIVE_VERSION = ['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim()
def REACT_NATIVE_MINOR_VERSION = REACT_NATIVE_VERSION.split("\\.")[1].toInteger()

android {
compileSdkVersion getExtOrIntegerDefault('compileSdkVersion')
buildToolsVersion getExtOrDefault('buildToolsVersion')
Expand All @@ -47,6 +78,44 @@ android {
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
externalNativeBuild {
cmake {
cppFlags "-O2 -frtti -fexceptions -Wall -Wno-unused-variable -fstack-protector-all"
if (REACT_NATIVE_MINOR_VERSION >= 71) {
// fabricjni required c++_shared
arguments "-DANDROID_STL=c++_shared", "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}"
} else {
arguments "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}"
}
abiFilters (*reactNativeArchitectures())
}
}
}

if (rootProject.hasProperty("ndkPath")) {
ndkPath rootProject.ext.ndkPath
}
if (rootProject.hasProperty("ndkVersion")) {
ndkVersion rootProject.ext.ndkVersion
}

buildFeatures {
prefab true
}

externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}

packagingOptions {
doNotStrip resolveBuildType() == 'debug' ? "**/**/*.so" : ''
excludes = [
"META-INF",
"META-INF/**",
"**/libjsi.so",
]
}

buildTypes {
Expand Down Expand Up @@ -149,8 +218,6 @@ repositories {
}
}

def REACT_NATIVE_VERSION = new File(['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim())

dependencies {
api "com.facebook.react:react-native:" + REACT_NATIVE_VERSION
api "org.mockito:mockito-core:2.28.2"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package ai.onnxruntime.reactnative;

import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.JavaOnlyMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.modules.blob.BlobModule;

public class FakeBlobModule extends BlobModule {

public FakeBlobModule(ReactApplicationContext context) { super(null); }

@Override
public String getName() {
return "BlobModule";
}

public JavaOnlyMap testCreateData(byte[] bytes) {
String blobId = store(bytes);
JavaOnlyMap data = new JavaOnlyMap();
data.putString("blobId", blobId);
data.putInt("offset", 0);
data.putInt("size", bytes.length);
return data;
}

public byte[] testGetData(ReadableMap data) {
String blobId = data.getString("blobId");
int offset = data.getInt("offset");
int size = data.getInt("size");
return resolve(blobId, offset, size);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
import android.util.Base64;
import androidx.test.platform.app.InstrumentationRegistry;
import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.CatalystInstance;
import com.facebook.react.bridge.JavaOnlyArray;
import com.facebook.react.bridge.JavaOnlyMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.modules.blob.BlobModule;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
Expand All @@ -29,12 +32,17 @@ public class OnnxruntimeModuleTest {
private ReactApplicationContext reactContext =
new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext());

private FakeBlobModule blobModule;

@Before
public void setUp() {}
public void setUp() {
blobModule = new FakeBlobModule(reactContext);
}

@Test
public void getName() throws Exception {
OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
String name = "Onnxruntime";
Assert.assertEquals(ortModule.getName(), name);
}
Expand All @@ -47,6 +55,7 @@ public void onnxruntime_module() throws Exception {
when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray());

OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
String sessionKey = "";

// test loadModel()
Expand Down Expand Up @@ -104,8 +113,7 @@ public void onnxruntime_module() throws Exception {
floatBuffer.put(value);
}
floatBuffer.rewind();
String dataEncoded = Base64.encodeToString(buffer.array(), Base64.DEFAULT);
inputTensorMap.putString("data", dataEncoded);
inputTensorMap.putMap("data", blobModule.testCreateData(buffer.array()));

inputDataMap.putMap("input", inputTensorMap);
}
Expand All @@ -124,10 +132,9 @@ public void onnxruntime_module() throws Exception {
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
}
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat);
String dataEncoded = outputMap.getString("data");
FloatBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT))
.order(ByteOrder.nativeOrder())
.asFloatBuffer();
ReadableMap data = outputMap.getMap("data");
FloatBuffer buffer =
ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer();
for (int i = 0; i < 5; ++i) {
Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f);
}
Expand Down
Loading

0 comments on commit ea1a5cf

Please sign in to comment.