Skip to content

Commit 010089e

Browse files
authored
[Hotfix] Fix the conflicts about ffi-related updated names (#18287)
* Change registration of mock softmax function
* Update check_asf_header.sh: remove unnecessary blank line in check_asf_header.sh
* Update check_asf_header.sh
* fix
1 parent c3b168b commit 010089e

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

python/tvm/relax/relax_to_pyfunc_converter.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727

2828
import tvm
2929
from tvm import relax
30+
from tvm.runtime import empty, from_dlpack, Tensor
3031
from tvm.ir import IRModule, Op
3132

3233

@@ -608,7 +609,7 @@ def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any:
608609
for arg in converted_args:
609610
if isinstance(arg, torch.Tensor):
610611
# Convert PyTorch tensor to TVM NDArray via DLPack
611-
tvm_arg = tvm.nd.from_dlpack(torch.to_dlpack(arg))
612+
tvm_arg = from_dlpack(torch.to_dlpack(arg))
612613
tvm_args.append(tvm_arg)
613614
else:
614615
tvm_args.append(arg)
@@ -627,15 +628,15 @@ def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any:
627628
return f"<call_tir_error: {func_name} - Cannot determine output shape>"
628629

629630
# Allocate output tensor
630-
output_tensor = tvm.nd.array(tvm.nd.empty(output_shape, dtype="float32"))
631+
output_tensor = empty(output_shape, dtype="float32")
631632
tvm_args.append(output_tensor)
632633

633634
# Call the TIR function
634635
tir_function(*tvm_args)
635636

636637
# The result is in the output_tensor we allocated
637638
# Convert result back to PyTorch tensor via DLPack
638-
return torch.from_dlpack(output_tensor.to_dlpack())
639+
return torch.from_dlpack(output_tensor)
639640

640641
except (RuntimeError, ValueError, TypeError) as error:
641642
return f"<call_tir_error: {func_name} - {error}>"
@@ -669,7 +670,7 @@ def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any:
669670
for arg in converted_args:
670671
if isinstance(arg, torch.Tensor):
671672
# Convert PyTorch tensor to TVM NDArray via DLPack
672-
tvm_arg = tvm.nd.from_dlpack(torch.to_dlpack(arg))
673+
tvm_arg = from_dlpack(torch.to_dlpack(arg))
673674
tvm_args.append(tvm_arg)
674675
else:
675676
tvm_args.append(arg)
@@ -678,8 +679,9 @@ def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any:
678679
result = packed_function(*tvm_args)
679680

680681
# Convert result back to PyTorch tensor via DLPack
681-
if isinstance(result, tvm.nd.NDArray):
682-
return torch.from_dlpack(result.to_dlpack())
682+
if isinstance(result, Tensor):
683+
# Convert TVM Tensor to PyTorch tensor
684+
return torch.from_dlpack(result)
683685
else:
684686
return result
685687

tests/python/relax/test_relax_to_pyfunc_converter.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -200,7 +200,7 @@ def mock_softmax(x, axis):
200200
return x
201201

202202
# Register the function globally
203-
tvm.register_func("my_softmax", mock_softmax)
203+
tvm.register_global_func("my_softmax", mock_softmax)
204204

205205

206206
class TestRelaxToPyFuncConverter:

0 commit comments

Comments (0)