Skip to content

Commit 1b6adcd

Browse files
authored
[ROCm] Fix ROCm build after FFI refactor (#18029)
This PR fixes a few places that failed compilation for ROCm after recent FFI refactors.
1 parent eca92bd commit 1b6adcd

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

src/relax/backend/contrib/hipblas/codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ Array<runtime::Module> HipblasCompiler(Array<Function> functions, Map<String, ff
9696
auto constant_names = serializer.GetConstantNames();
9797
const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.HipblasJSONRuntimeCreate");
9898
auto func_name = GetExtSymbol(func);
99-
compiled_functions.push_back((*pf)(func_name, graph_json, constant_names));
99+
compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast<runtime::Module>());
100100
}
101101

102102
return compiled_functions;

src/runtime/contrib/hipblas/hipblas.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,8 @@ inline void CallGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t hdl)
300300
<< "leading dimension must divide 4 for int8 gemm";
301301
ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0)
302302
<< "leading dimension must divide 4 for int8 gemm";
303-
double alpha = args.size() > 5 ? args[5] : 1.0;
304-
double beta = args.size() > 6 ? args[6] : 0.0;
303+
double alpha = args.size() > 5 ? args[5].cast<double>() : 1.0;
304+
double beta = args.size() > 6 ? args[6].cast<double>() : 0.0;
305305

306306
hipblasDatatype_t hip_in_type = GetHipBlasDataType(A->dtype);
307307
hipblasDatatype_t hip_out_type = GetHipBlasDataType(C->dtype);
@@ -359,8 +359,8 @@ inline void CallBatchGemmEx(ffi::PackedArgs args, ffi::Any* ret, hipblasHandle_t
359359
<< "leading dimension must divide 4 for int8 gemm";
360360
ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride3D(B) % 4 == 0)
361361
<< "leading dimension must divide 4 for int8 gemm";
362-
double alpha = args.size() > 5 ? args[5] : 1.0;
363-
double beta = args.size() > 6 ? args[6] : 0.0;
362+
double alpha = args.size() > 5 ? args[5].cast<double>() : 1.0;
363+
double beta = args.size() > 6 ? args[6].cast<double>() : 0.0;
364364

365365
int A_stride = A->shape[1] * A->shape[2];
366366
int B_stride = B->shape[1] * B->shape[2];

src/runtime/contrib/hipblas/hipblas_json_runtime.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,10 @@ class HipblasJSONRuntime : public JSONRuntimeBase {
7272
for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
7373
auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
7474
: EntryID(outputs_[i - input_var_eid_.size()]);
75-
ICHECK(args[i].type_code() == kTVMNDArrayHandle || args[i].type_code() == kTVMDLTensorHandle)
76-
<< "Expect NDArray or DLTensor as inputs";
7775

7876
const DLTensor* arg;
79-
if (args[i].IsObjectRef<NDArray>()) {
80-
NDArray arr = args[i];
77+
if (auto opt_nd = args[i].as<NDArray>()) {
78+
NDArray arr = opt_nd.value();
8179
arg = arr.operator->();
8280
} else {
8381
arg = args[i].cast<DLTensor*>();

0 commit comments

Comments
 (0)