Skip to content

Commit 20fc9cd

Browse files
committed
[FFI][JVM] Upgrade tvm4j to latest FFI
This PR updates TVM4J to use the latest FFI
1 parent 942d03c commit 20fc9cd

28 files changed

+554
-611
lines changed

jvm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ TVM4J contains three modules:
3939
- core
4040
* It contains all the Java interfaces.
4141
- native
42-
* The JNI native library is compiled in this module. It does not link TVM runtime library (libtvm\_runtime.so for Linux and libtvm\_runtime.dylib for OSX). Instead, you have to specify `libtvm.so.path` which contains the TVM runtime library as Java system property.
42+
* The JNI native library is compiled in this module. Need to expose libtvm_runtime to LD_LIBRARY_PATH
4343
- assembly
4444
* It assembles Java interfaces (core), JNI library (native) and TVM runtime library together. The simplest way to integrate tvm4j in your project is to rely on this module. It automatically extracts the native library to a tempfile and load it.
4545

jvm/core/src/main/java/org/apache/tvm/Base.java

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -87,37 +87,8 @@ public RefTVMValue() {
8787
}
8888

8989
System.err.println("libtvm4j loads successfully.");
90-
91-
if (loadNativeRuntimeLib) {
92-
String tvmLibFilename = System.getProperty("libtvm.so.path");
93-
if (tvmLibFilename == null || !new File(tvmLibFilename).isFile()
94-
|| _LIB.nativeLibInit(tvmLibFilename) != 0) {
95-
try {
96-
String runtimeLibname;
97-
String os = System.getProperty("os.name");
98-
// ref: http://lopica.sourceforge.net/os.html
99-
if (os.startsWith("Linux")) {
100-
runtimeLibname = "libtvm_runtime.so";
101-
} else if (os.startsWith("Mac")) {
102-
runtimeLibname = "libtvm_runtime.dylib";
103-
} else {
104-
// TODO(yizhi) support windows later
105-
throw new UnsatisfiedLinkError(os + " not supported currently");
106-
}
107-
NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() {
108-
@Override public void invoke(File target) {
109-
System.err.println("Loading tvm runtime from " + target.getPath());
110-
checkCall(_LIB.nativeLibInit(target.getPath()));
111-
}
112-
});
113-
} catch (IOException e) {
114-
throw new RuntimeException(e);
115-
}
116-
}
117-
} else {
118-
_LIB.nativeLibInit(null);
119-
}
120-
90+
// always use linked lib
91+
_LIB.nativeLibInit(null);
12192
Runtime.getRuntime().addShutdownHook(new Thread() {
12293
@Override public void run() {
12394
_LIB.shutdown();
@@ -170,7 +141,7 @@ private static void tryLoadLibraryXPU(String libname, String arch) throws Unsati
170141
*/
171142
public static void checkCall(int ret) throws TVMError {
172143
if (ret != 0) {
173-
throw new TVMError(_LIB.tvmGetLastError());
144+
throw new TVMError(_LIB.tvmFFIGetLastError());
174145
}
175146
}
176147

jvm/core/src/main/java/org/apache/tvm/Device.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,30 @@
1717

1818
package org.apache.tvm;
1919

20+
import org.apache.tvm.rpc.RPC;
21+
2022
import java.util.HashMap;
2123
import java.util.Map;
22-
import org.apache.tvm.rpc.RPC;
2324

2425
public class Device {
2526
/**
2627
* Provides the same information as the C++ enums DLDeviceType and
2728
* TVMDeviceExtType.
2829
*/
29-
static final int kDLCPU = 1, kDLCUDA = 2, kDLCUDAHost = 3, kDLOpenCL = 4, kDLVulkan = 7,
30-
kDLMetal = 8, kDLVPI = 9, kDLROCM = 10, kDLROCMHost = 11, kDLExtDev = 12,
31-
kDLCUDAManaged = 13, kDLOneAPI = 14, kDLWebGPU = 15, kDLHexagon = 16;
30+
static final int kDLCPU = 1;
31+
static final int kDLCUDA = 2;
32+
static final int kDLCUDAHost = 3;
33+
static final int kDLOpenCL = 4;
34+
static final int kDLVulkan = 7;
35+
static final int kDLMetal = 8;
36+
static final int kDLVPI = 9;
37+
static final int kDLROCM = 10;
38+
static final int kDLROCMHost = 11;
39+
static final int kDLExtDev = 12;
40+
static final int kDLCUDAManaged = 13;
41+
static final int kDLOneAPI = 14;
42+
static final int kDLWebGPU = 15;
43+
static final int kDLHexagon = 16;
3244

3345
private static final Map<Integer, String> DEVICE_TYPE_TO_NAME = new HashMap<Integer, String>();
3446
private static final Map<String, Integer> DEVICE_NAME_TO_TYPE = new HashMap<String, Integer>();
@@ -161,7 +173,8 @@ public Device(String deviceType, int deviceId) {
161173
*/
162174
public boolean exist() {
163175
TVMValue ret =
164-
APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(0).invoke();
176+
APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType)
177+
.pushArg(deviceId).pushArg(0).invoke();
165178
return ((TVMValueLong) ret).value != 0;
166179
}
167180

@@ -171,7 +184,8 @@ public boolean exist() {
171184
*/
172185
public long maxThreadsPerBlock() {
173186
TVMValue ret =
174-
APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(1).invoke();
187+
APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType)
188+
.pushArg(deviceId).pushArg(1).invoke();
175189
return ((TVMValueLong) ret).value;
176190
}
177191

@@ -181,8 +195,9 @@ public long maxThreadsPerBlock() {
181195
*/
182196
public long warpSize() {
183197
TVMValue ret =
184-
APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(2).invoke();
185-
return ((TVMValueLong) ret).value;
198+
APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType)
199+
.pushArg(deviceId).pushArg(2).invoke();
200+
return ret.asLong();
186201
}
187202

188203
/**

0 commit comments

Comments
 (0)