Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 148 additions & 19 deletions jvm/core/src/main/java/ml/dmlc/tvm/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import java.util.Collections;
import java.util.List;

public class Function {
/**
* TVM Packed Function.
*/
public class Function extends TVMValue {
final long handle;
public final boolean isResident;
private boolean isReleased = false;
Expand Down Expand Up @@ -76,24 +79,41 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a
* @param handle the handle to the underlying function.
* @param isResident Whether this is a resident function in jvm
*/
public Function(long handle, boolean isResident) {
Function(long handle, boolean isResident) {
super(TypeCode.FUNC_HANDLE);
this.handle = handle;
this.isResident = isResident;
}

Function(long handle) {
this(handle, false);
}

@Override protected void finalize() throws Throwable {
release();
super.finalize();
}

/**
* Easy for user to get the instance from returned TVMValue.
* @return this
*/
@Override public Function asFunction() {
return this;
}

@Override long asHandle() {
return handle;
}

/**
* Release the Function.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
@Override public void release() {
if (!isReleased) {
if (!isResident) {
Base.checkCall(Base._LIB.tvmFuncFree(handle));
Expand Down Expand Up @@ -167,34 +187,143 @@ public Function pushArg(String arg) {
* @param arg NDArray.
* @return this
*/
public Function pushArg(NDArray arg) {
public Function pushArg(NDArrayBase arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.ARRAY_HANDLE.id);
return this;
}

/**
* Push argument to the function.
* @param arg Module.
* @return this
*/
public Function pushArg(Module arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.MODULE_HANDLE.id);
return this;
}

/**
* Push argument to the function.
* @param arg Function.
* @return this
*/
public Function pushArg(Function arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.FUNC_HANDLE.id);
return this;
}

/**
* Push argument to the function.
* @param arg bytes.
* @return this
*/
public Function pushArg(byte[] arg) {
Base._LIB.tvmFuncPushArgBytes(arg);
return this;
}

/**
* Invoke function with arguments.
* @param args Can be Integer, Long, Float, Double, String, NDArray.
* @return the result.
*/
public TVMValue call(Object... args) {
for (Object arg : args) {
if (arg instanceof Integer) {
pushArg((Integer) arg);
} else if (arg instanceof Long) {
pushArg((Long) arg);
} else if (arg instanceof Float) {
pushArg((Float) arg);
} else if (arg instanceof Double) {
pushArg((Double) arg);
} else if (arg instanceof String) {
pushArg((String) arg);
} else if (arg instanceof NDArray) {
pushArg((NDArray) arg);
} else {
throw new IllegalArgumentException("Invalid argument: " + arg);
}
pushArgToStack(arg);
}
return invoke();
}

private static void pushArgToStack(Object arg) {
if (arg instanceof Integer) {
Base._LIB.tvmFuncPushArgLong((Integer) arg);
} else if (arg instanceof Long) {
Base._LIB.tvmFuncPushArgLong((Long) arg);
} else if (arg instanceof Float) {
Base._LIB.tvmFuncPushArgDouble((Float) arg);
} else if (arg instanceof Double) {
Base._LIB.tvmFuncPushArgDouble((Double) arg);
} else if (arg instanceof String) {
Base._LIB.tvmFuncPushArgString((String) arg);
} else if (arg instanceof byte[]) {
Base._LIB.tvmFuncPushArgBytes((byte[]) arg);
} else if (arg instanceof NDArrayBase) {
Base._LIB.tvmFuncPushArgHandle(((NDArrayBase) arg).handle, TypeCode.ARRAY_HANDLE.id);
} else if (arg instanceof Module) {
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id);
} else if (arg instanceof Function) {
Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, TypeCode.FUNC_HANDLE.id);
} else if (arg instanceof TVMValue) {
TVMValue tvmArg = (TVMValue) arg;
switch (tvmArg.typeCode) {
case UINT:
case INT:
Base._LIB.tvmFuncPushArgLong(tvmArg.asLong());
break;
case FLOAT:
Base._LIB.tvmFuncPushArgDouble(tvmArg.asDouble());
break;
case STR:
Base._LIB.tvmFuncPushArgString(tvmArg.asString());
break;
case BYTES:
Base._LIB.tvmFuncPushArgBytes(tvmArg.asBytes());
break;
case ARRAY_HANDLE:
case MODULE_HANDLE:
case FUNC_HANDLE:
Base._LIB.tvmFuncPushArgHandle(tvmArg.asHandle(), tvmArg.typeCode.id);
break;
default:
throw new IllegalArgumentException("Invalid argument: " + arg);
}
} else {
throw new IllegalArgumentException("Invalid argument: " + arg);
}
}

public static interface Callback {
public Object invoke(TVMValue... args);
}

/**
* Register user-defined global function.
* @param name The function name.
* @param function The function to be registered.
* @param override Whether override existing entry.
*/
public static void register(String name, Callback function, boolean override) {
Base.RefLong createdFuncHandleRef = new Base.RefLong();
Base.checkCall(Base._LIB.tvmFuncCreateFromCFunc(function, createdFuncHandleRef));
int ioverride = override ? 1 : 0;
Base.checkCall(Base._LIB.tvmFuncRegisterGlobal(name, createdFuncHandleRef.value, ioverride));
}

/**
* Register user-defined global function, do not override existing entry.
* @param name The function name.
* @param function The function to be registered.
*/
public static void register(String name, Callback function) {
register(name, function, false);
}

/**
* Convert a Java function to TVM function.
* @param function Java function.
* @return TVM function.
*/
public static Function convertFunc(Callback function) {
Base.RefLong createdFuncHandleRef = new Base.RefLong();
Base.checkCall(Base._LIB.tvmFuncCreateFromCFunc(function, createdFuncHandleRef));
return new Function(createdFuncHandleRef.value);
}

private static Object invokeRegisteredCbFunc(Callback cb, TVMValue[] args) {
if (cb == null) {
System.err.println("[ERROR] Failed to get registered function");
return null;
}
return cb.invoke(args);
}
}
55 changes: 28 additions & 27 deletions jvm/core/src/main/java/ml/dmlc/tvm/LibInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,56 +20,57 @@
import java.util.List;

class LibInfo {
public native int nativeLibInit(String tvmLibFile);
native int nativeLibInit(String tvmLibFile);

public native int shutdown();
native int shutdown();

public native String tvmGetLastError();
native String tvmGetLastError();

// Function
public native void tvmFuncPushArgLong(long arg);
native void tvmFuncPushArgLong(long arg);

public native void tvmFuncPushArgDouble(double arg);
native void tvmFuncPushArgDouble(double arg);

public native void tvmFuncPushArgString(String arg);
native void tvmFuncPushArgString(String arg);

public native void tvmFuncPushArgHandle(long arg, int argType);
native void tvmFuncPushArgBytes(byte[] arg);

public native int tvmFuncListGlobalNames(List<String> funcNames);
native void tvmFuncPushArgHandle(long arg, int argType);

public native int tvmFuncFree(long handle);
native int tvmFuncListGlobalNames(List<String> funcNames);

public native int tvmFuncGetGlobal(String name, Base.RefLong handle);
native int tvmFuncFree(long handle);

public native int tvmFuncCall(long handle, Base.RefTVMValue retVal);
native int tvmFuncGetGlobal(String name, Base.RefLong handle);

native int tvmFuncCall(long handle, Base.RefTVMValue retVal);

native int tvmFuncCreateFromCFunc(Function.Callback function, Base.RefLong handle);

native int tvmFuncRegisterGlobal(String name, long handle, int override);

// Module
public native int tvmModFree(long handle);
native int tvmModFree(long handle);

public native int tvmModGetFunction(long handle, String name,
native int tvmModGetFunction(long handle, String name,
int queryImports, Base.RefLong retHandle);

public native int tvmModImport(long mod, long dep);
native int tvmModImport(long mod, long dep);

// NDArray
public native int tvmArrayFree(long handle);
native int tvmArrayFree(long handle);

public native int tvmArrayAlloc(long[] shape,
int dtypeCode,
int dtypeBits,
int dtypeLanes,
int deviceType,
int deviceId,
Base.RefLong refHandle);
native int tvmArrayAlloc(long[] shape, int dtypeCode, int dtypeBits, int dtypeLanes,
int deviceType, int deviceId, Base.RefLong refHandle);

public native int tvmArrayGetShape(long handle, List<Long> shape);
native int tvmArrayGetShape(long handle, List<Long> shape);

public native int tvmArrayCopyFromTo(long from, long to);
native int tvmArrayCopyFromTo(long from, long to);

public native int tvmArrayCopyFromJArray(byte[] fromRaw, long from, long to);
native int tvmArrayCopyFromJArray(byte[] fromRaw, long from, long to);

public native int tvmArrayCopyToJArray(long from, byte[] to);
native int tvmArrayCopyToJArray(long from, byte[] to);

// TVMContext
public native int tvmSynchronize(int deviceType, int deviceId);
native int tvmSynchronize(int deviceType, int deviceId);
}
19 changes: 16 additions & 3 deletions jvm/core/src/main/java/ml/dmlc/tvm/Module.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
/**
* Container of compiled functions of TVM.
*/
public class Module {
public class Module extends TVMValue {
public final long handle;
private boolean isReleased = false;

Expand All @@ -44,7 +44,8 @@ private static Function getApi(String name) {
return func;
}

public Module(long handle) {
Module(long handle) {
super(TypeCode.MODULE_HANDLE);
this.handle = handle;
}

Expand All @@ -56,14 +57,26 @@ public Module(long handle) {
super.finalize();
}

/**
* Easy for user to get the instance from returned TVMValue.
* @return this
*/
@Override public Module asModule() {
return this;
}

@Override long asHandle() {
return handle;
}

/**
* Release the Module.
* <p>
* We highly recommend you to do this manually since the GC strategy is lazy
* and `finalize()` is not guaranteed to be called when GC happens.
* </p>
*/
public void release() {
@Override public void release() {
if (!isReleased) {
Base.checkCall(Base._LIB.tvmModFree(handle));
isReleased = true;
Expand Down
Loading