Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/rn] Implement blob exchange by JSI instead of use base64 #16094

Merged
merged 32 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
d6035ad
iOS: Implement blob exchange by JSI instead of base64
jhen0409 May 24, 2023
6043c7f
[js/rn] Android: Implement blob exchange by JSI instead of base64
jhen0409 May 25, 2023
d070d33
iOS: Remove unnecessary import
jhen0409 May 25, 2023
1f442b2
Android: Fix cmake build on RN v0.69
jhen0409 May 25, 2023
c21dc3c
Android: Update java tests
jhen0409 May 25, 2023
585d4dc
iOS: Update objective-c tests
jhen0409 May 25, 2023
62893f3
JS: Fix resolved buffer type
jhen0409 May 25, 2023
b5af470
Android: Use getReactApplicationContext() for get blob manager
jhen0409 May 25, 2023
1d336fa
Android: Revert unnecessary changes
jhen0409 May 25, 2023
3c05e1d
iOS: Correct copyright for new test files
jhen0409 May 25, 2023
7897572
Android: Fix build for React Native v0.71
jhen0409 May 26, 2023
0728034
TS: Remove global functions after installation
jhen0409 May 27, 2023
d66f095
iOS: Delete blobManager ref on dealloc
jhen0409 May 27, 2023
034d201
Android: Correct JSI function names
jhen0409 Jun 1, 2023
e5e5e6b
iOS: Correct JSI function names & add args count check
jhen0409 Jun 1, 2023
c6309db
Cleanup unnecessary includes & imports & logs
jhen0409 Jun 2, 2023
d4660bb
Android: Refactor getBytesFromBlob / createBlob cpp functions
jhen0409 Jun 2, 2023
7dc50fc
Check if blob manager not initialized
jhen0409 Jun 2, 2023
f3202e9
Android: Correct JSI functions name
jhen0409 Jun 2, 2023
7c87980
Android: Remove unnecessary Base64 import
jhen0409 Jun 2, 2023
fc0a513
Update comments
jhen0409 Jun 2, 2023
277c0b7
allocatons -> allocations
jhen0409 Jun 2, 2023
46508e5
Fix typo
jhen0409 Jun 2, 2023
7345896
Merge branch 'main' into jhen-rn-use-blob
jhen0409 Jun 9, 2023
2754d60
Run format
jhen0409 Jun 9, 2023
80e98ed
TS: Update types & comments
jhen0409 Jun 9, 2023
9f0f515
Android: Remove unnecessary script
jhen0409 Jun 9, 2023
1cb5f63
Fix lint errors & format again
jhen0409 Jun 10, 2023
d212271
Merge branch 'main' into jhen-rn-use-blob
jhen0409 Jun 12, 2023
f071c74
Fix lint error
jhen0409 Jun 12, 2023
f08ab3a
Merge remote-tracking branch 'origin' into jhen-rn-use-blob
jhen0409 Jun 14, 2023
7f9c916
Merge branch 'main' into jhen-rn-use-blob
jhen0409 Jun 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions js/react_native/android/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
project(OnnxruntimeJSIHelper)
cmake_minimum_required(VERSION 3.9.0)

set (PACKAGE_NAME "onnxruntime-react-native")
set (BUILD_DIR ${CMAKE_SOURCE_DIR}/build)
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_CXX_STANDARD 17)

file(TO_CMAKE_PATH "${NODE_MODULES_DIR}/react-native/ReactCommon/jsi/jsi/jsi.cpp" libPath)

include_directories(
"${NODE_MODULES_DIR}/react-native/React"
"${NODE_MODULES_DIR}/react-native/React/Base"
"${NODE_MODULES_DIR}/react-native/ReactCommon/jsi"
)

add_library(onnxruntimejsihelper
SHARED
${libPath}
src/main/cpp/cpp-adapter.cpp
)

# Configure C++ 17
set_target_properties(
onnxruntimejsihelper PROPERTIES
CXX_STANDARD 17
CXX_EXTENSIONS OFF
POSITION_INDEPENDENT_CODE ON
)

find_library(log-lib log)

target_link_libraries(
onnxruntimejsihelper
${log-lib} # <-- Logcat logger
android # <-- Android JNI core
)
71 changes: 69 additions & 2 deletions js/react_native/android/build.gradle
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import java.nio.file.Paths

buildscript {
repositories {
google()
Expand All @@ -20,6 +22,32 @@ def getExtOrIntegerDefault(name) {
return rootProject.ext.has(name) ? rootProject.ext.get(name) : (project.properties['OnnxruntimeModule_' + name]).toInteger()
}

def reactNativeArchitectures() {
def value = project.getProperties().get("reactNativeArchitectures")
return value ? value.split(",") : ["armeabi-v7a", "x86", "x86_64", "arm64-v8a"]
}

def resolveBuildType() {
Gradle gradle = getGradle()
String tskReqStr = gradle.getStartParameter().getTaskRequests()['args'].toString()
return tskReqStr.contains('Release') ? 'release' : 'debug'
}

static def findNodeModules(baseDir) {
def basePath = baseDir.toPath().normalize()
while (basePath) {
def nodeModulesPath = Paths.get(basePath.toString(), "node_modules")
def reactNativePath = Paths.get(nodeModulesPath.toString(), "react-native")
if (nodeModulesPath.toFile().exists() && reactNativePath.toFile().exists()) {
return nodeModulesPath.toString()
}
basePath = basePath.getParent()
}
throw new GradleException("onnxruntime-react-native: Failed to find node_modules/ path!")
}

def nodeModules = findNodeModules(projectDir);

def checkIfOrtExtensionsEnabled() {
// locate user's project dir
def reactnativeRootDir = project.rootDir.parentFile
Expand All @@ -38,6 +66,9 @@ def checkIfOrtExtensionsEnabled() {

boolean ortExtensionsEnabled = checkIfOrtExtensionsEnabled()

def REACT_NATIVE_VERSION = ['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim()
def REACT_NATIVE_MINOR_VERSION = REACT_NATIVE_VERSION.split("\\.")[1].toInteger()

android {
compileSdkVersion getExtOrIntegerDefault('compileSdkVersion')
buildToolsVersion getExtOrDefault('buildToolsVersion')
Expand All @@ -47,6 +78,44 @@ android {
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
externalNativeBuild {
cmake {
cppFlags "-O2 -frtti -fexceptions -Wall -Wno-unused-variable -fstack-protector-all"
if (REACT_NATIVE_MINOR_VERSION >= 71) {
// fabricjni required c++_shared
arguments "-DANDROID_STL=c++_shared", "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}"
} else {
arguments "-DNODE_MODULES_DIR=${nodeModules}", "-DORT_EXTENSIONS_ENABLED=${ortExtensionsEnabled}"
}
abiFilters (*reactNativeArchitectures())
}
}
}

if (rootProject.hasProperty("ndkPath")) {
ndkPath rootProject.ext.ndkPath
}
if (rootProject.hasProperty("ndkVersion")) {
ndkVersion rootProject.ext.ndkVersion
}

buildFeatures {
prefab true
}

externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}

packagingOptions {
doNotStrip resolveBuildType() == 'debug' ? "**/**/*.so" : ''
excludes = [
"META-INF",
"META-INF/**",
"**/libjsi.so",
]
}

buildTypes {
Expand Down Expand Up @@ -149,8 +218,6 @@ repositories {
}
}

def REACT_NATIVE_VERSION = new File(['node', '--print', "JSON.parse(require('fs').readFileSync(require.resolve('react-native/package.json'), 'utf-8')).version"].execute(null, rootDir).text.trim())

dependencies {
api "com.facebook.react:react-native:" + REACT_NATIVE_VERSION
api "org.mockito:mockito-core:2.28.2"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package ai.onnxruntime.reactnative;

import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.JavaOnlyMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.modules.blob.BlobModule;

public class FakeBlobModule extends BlobModule {

public FakeBlobModule(ReactApplicationContext context) { super(null); }

@Override
public String getName() {
return "BlobModule";
}

public JavaOnlyMap testCreateData(byte[] bytes) {
String blobId = store(bytes);
JavaOnlyMap data = new JavaOnlyMap();
data.putString("blobId", blobId);
data.putInt("offset", 0);
data.putInt("size", bytes.length);
return data;
}

public byte[] testGetData(ReadableMap data) {
String blobId = data.getString("blobId");
int offset = data.getInt("offset");
int size = data.getInt("size");
return resolve(blobId, offset, size);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
import android.util.Base64;
import androidx.test.platform.app.InstrumentationRegistry;
import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.CatalystInstance;
import com.facebook.react.bridge.JavaOnlyArray;
import com.facebook.react.bridge.JavaOnlyMap;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.modules.blob.BlobModule;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
Expand All @@ -29,12 +32,17 @@ public class OnnxruntimeModuleTest {
private ReactApplicationContext reactContext =
new ReactApplicationContext(InstrumentationRegistry.getInstrumentation().getContext());

private FakeBlobModule blobModule;

@Before
public void setUp() {}
public void setUp() {
blobModule = new FakeBlobModule(reactContext);
}

@Test
public void getName() throws Exception {
OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
String name = "Onnxruntime";
Assert.assertEquals(ortModule.getName(), name);
}
Expand All @@ -47,6 +55,7 @@ public void onnxruntime_module() throws Exception {
when(Arguments.createArray()).thenAnswer(i -> new JavaOnlyArray());

OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
String sessionKey = "";

// test loadModel()
Expand Down Expand Up @@ -104,8 +113,7 @@ public void onnxruntime_module() throws Exception {
floatBuffer.put(value);
}
floatBuffer.rewind();
String dataEncoded = Base64.encodeToString(buffer.array(), Base64.DEFAULT);
inputTensorMap.putString("data", dataEncoded);
inputTensorMap.putMap("data", blobModule.testCreateData(buffer.array()));

inputDataMap.putMap("input", inputTensorMap);
}
Expand All @@ -124,10 +132,9 @@ public void onnxruntime_module() throws Exception {
Assert.assertEquals(outputMap.getArray("dims").getInt(i), dims[i]);
}
Assert.assertEquals(outputMap.getString("type"), TensorHelper.JsTensorTypeFloat);
String dataEncoded = outputMap.getString("data");
FloatBuffer buffer = ByteBuffer.wrap(Base64.decode(dataEncoded, Base64.DEFAULT))
.order(ByteOrder.nativeOrder())
.asFloatBuffer();
ReadableMap data = outputMap.getMap("data");
FloatBuffer buffer =
ByteBuffer.wrap(blobModule.testGetData(data)).order(ByteOrder.nativeOrder()).asFloatBuffer();
for (int i = 0; i < 5; ++i) {
Assert.assertEquals(buffer.get(i), inputData[i], 1e-6f);
}
Expand Down
Loading