2727
2828import tvm
2929from tvm import relax
30+ from tvm .runtime import empty , from_dlpack , Tensor
3031from tvm .ir import IRModule , Op
3132
3233
@@ -608,7 +609,7 @@ def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any:
608609 for arg in converted_args :
609610 if isinstance (arg , torch .Tensor ):
610611 # Convert PyTorch tensor to TVM NDArray via DLPack
611- tvm_arg = tvm . nd . from_dlpack (torch .to_dlpack (arg ))
612+ tvm_arg = from_dlpack (torch .to_dlpack (arg ))
612613 tvm_args .append (tvm_arg )
613614 else :
614615 tvm_args .append (arg )
@@ -627,15 +628,15 @@ def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any:
627628 return f"<call_tir_error: { func_name } - Cannot determine output shape>"
628629
629630 # Allocate output tensor
630- output_tensor = tvm . nd . array ( tvm . nd . empty (output_shape , dtype = "float32" ) )
631+ output_tensor = empty (output_shape , dtype = "float32" )
631632 tvm_args .append (output_tensor )
632633
633634 # Call the TIR function
634635 tir_function (* tvm_args )
635636
636637 # The result is in the output_tensor we allocated
637638 # Convert result back to PyTorch tensor via DLPack
638- return torch .from_dlpack (output_tensor . to_dlpack () )
639+ return torch .from_dlpack (output_tensor )
639640
640641 except (RuntimeError , ValueError , TypeError ) as error :
641642 return f"<call_tir_error: { func_name } - { error } >"
@@ -669,7 +670,7 @@ def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any:
669670 for arg in converted_args :
670671 if isinstance (arg , torch .Tensor ):
671672 # Convert PyTorch tensor to TVM NDArray via DLPack
672- tvm_arg = tvm . nd . from_dlpack (torch .to_dlpack (arg ))
673+ tvm_arg = from_dlpack (torch .to_dlpack (arg ))
673674 tvm_args .append (tvm_arg )
674675 else :
675676 tvm_args .append (arg )
@@ -678,8 +679,9 @@ def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any:
678679 result = packed_function (* tvm_args )
679680
680681 # Convert result back to PyTorch tensor via DLPack
681- if isinstance (result , tvm .nd .NDArray ):
682- return torch .from_dlpack (result .to_dlpack ())
682+ if isinstance (result , Tensor ):
683+ # Convert TVM Tensor to PyTorch tensor
684+ return torch .from_dlpack (result )
683685 else :
684686 return result
685687
0 commit comments