Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jvm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TVM4J contains three modules:
- core
* It contains all the Java interfaces.
- native
* The JNI native library is compiled in this module. It does not link TVM runtime library (libtvm\_runtime.so for Linux and libtvm\_runtime.dylib for OSX). Instead, you have to specify `libtvm.so.path` which contains the TVM runtime library as Java system property.
* The JNI native library is compiled in this module. Need to expose libtvm_runtime to LD_LIBRARY_PATH
- assembly
* It assembles Java interfaces (core), JNI library (native) and TVM runtime library together. The simplest way to integrate tvm4j in your project is to rely on this module. It automatically extracts the native library to a tempfile and load it.

Expand Down
35 changes: 3 additions & 32 deletions jvm/core/src/main/java/org/apache/tvm/Base.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,37 +87,8 @@ public RefTVMValue() {
}

System.err.println("libtvm4j loads successfully.");

if (loadNativeRuntimeLib) {
String tvmLibFilename = System.getProperty("libtvm.so.path");
if (tvmLibFilename == null || !new File(tvmLibFilename).isFile()
|| _LIB.nativeLibInit(tvmLibFilename) != 0) {
try {
String runtimeLibname;
String os = System.getProperty("os.name");
// ref: http://lopica.sourceforge.net/os.html
if (os.startsWith("Linux")) {
runtimeLibname = "libtvm_runtime.so";
} else if (os.startsWith("Mac")) {
runtimeLibname = "libtvm_runtime.dylib";
} else {
// TODO(yizhi) support windows later
throw new UnsatisfiedLinkError(os + " not supported currently");
}
NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() {
@Override public void invoke(File target) {
System.err.println("Loading tvm runtime from " + target.getPath());
checkCall(_LIB.nativeLibInit(target.getPath()));
}
});
} catch (IOException e) {
throw new RuntimeException(e);
}
}
} else {
_LIB.nativeLibInit(null);
}

// always use linked lib
_LIB.nativeLibInit(null);
Runtime.getRuntime().addShutdownHook(new Thread() {
@Override public void run() {
_LIB.shutdown();
Expand Down Expand Up @@ -170,7 +141,7 @@ private static void tryLoadLibraryXPU(String libname, String arch) throws Unsati
*/
public static void checkCall(int ret) throws TVMError {
if (ret != 0) {
throw new TVMError(_LIB.tvmGetLastError());
throw new TVMError(_LIB.tvmFFIGetLastError());
}
}

Expand Down
31 changes: 23 additions & 8 deletions jvm/core/src/main/java/org/apache/tvm/Device.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,30 @@

package org.apache.tvm;

import org.apache.tvm.rpc.RPC;

import java.util.HashMap;
import java.util.Map;
import org.apache.tvm.rpc.RPC;

public class Device {
/**
* Provides the same information as the C++ enums DLDeviceType and
* TVMDeviceExtType.
*/
static final int kDLCPU = 1, kDLCUDA = 2, kDLCUDAHost = 3, kDLOpenCL = 4, kDLVulkan = 7,
kDLMetal = 8, kDLVPI = 9, kDLROCM = 10, kDLROCMHost = 11, kDLExtDev = 12,
kDLCUDAManaged = 13, kDLOneAPI = 14, kDLWebGPU = 15, kDLHexagon = 16;
static final int kDLCPU = 1;
static final int kDLCUDA = 2;
static final int kDLCUDAHost = 3;
static final int kDLOpenCL = 4;
static final int kDLVulkan = 7;
static final int kDLMetal = 8;
static final int kDLVPI = 9;
static final int kDLROCM = 10;
static final int kDLROCMHost = 11;
static final int kDLExtDev = 12;
static final int kDLCUDAManaged = 13;
static final int kDLOneAPI = 14;
static final int kDLWebGPU = 15;
static final int kDLHexagon = 16;

private static final Map<Integer, String> DEVICE_TYPE_TO_NAME = new HashMap<Integer, String>();
private static final Map<String, Integer> DEVICE_NAME_TO_TYPE = new HashMap<String, Integer>();
Expand Down Expand Up @@ -161,7 +173,8 @@ public Device(String deviceType, int deviceId) {
*/
public boolean exist() {
TVMValue ret =
APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(0).invoke();
APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType)
.pushArg(deviceId).pushArg(0).invoke();
return ((TVMValueLong) ret).value != 0;
}

Expand All @@ -171,7 +184,8 @@ public boolean exist() {
*/
public long maxThreadsPerBlock() {
TVMValue ret =
APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(1).invoke();
APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType)
.pushArg(deviceId).pushArg(1).invoke();
return ((TVMValueLong) ret).value;
}

Expand All @@ -181,8 +195,9 @@ public long maxThreadsPerBlock() {
*/
public long warpSize() {
TVMValue ret =
APIInternal.get("_GetDeviceAttr").pushArg(deviceType).pushArg(deviceId).pushArg(2).invoke();
return ((TVMValueLong) ret).value;
APIInternal.get("runtime.GetDeviceAttr").pushArg(deviceType)
.pushArg(deviceId).pushArg(2).invoke();
return ret.asLong();
}

/**
Expand Down
Loading
Loading