3232except ImportError :
3333 to_dlpack_legacy = None
3434
35+ try :
36+ from tvm_ffi ._optional_torch_c_dlpack import load_torch_c_dlpack_extension
37+
38+ _FASTER_DLPACK_EXTENSION = load_torch_c_dlpack_extension ()
39+ except ImportError :
40+ _FASTER_DLPACK_EXTENSION = None
41+
3542
3643class BasePyModule :
3744 """Base class that allows Python functions in IRModule with DLPack conversion.
@@ -369,20 +376,29 @@ def _convert_pytorch_to_tvm(
369376 return self ._convert_single_pytorch_to_tvm (tensors )
370377
371378 def _convert_single_pytorch_to_tvm (self , tensor : Any ) -> Tensor :
372- """Convert a single PyTorch tensor to TVM Tensor with robust fallbacks ."""
379+ """Convert a single PyTorch tensor to TVM Tensor with faster DLPack converter ."""
373380 # pylint: disable=import-outside-toplevel
374381 import torch
375382
376383 if isinstance (tensor , Tensor ):
377384 return tensor
378385 if isinstance (tensor , torch .Tensor ):
379- # 1. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
386+ # 1. Try faster C++ DLPack converter
387+ if _FASTER_DLPACK_EXTENSION is not None :
388+ try :
389+ dlpack = torch .to_dlpack (tensor )
390+ return tvm .runtime .from_dlpack (dlpack )
391+ except (AttributeError , ValueError ):
392+ pass # Fall through to the next method
393+
394+ # 2. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
380395 try :
381396 dlpack = torch .to_dlpack (tensor )
382397 return tvm .runtime .from_dlpack (dlpack )
383398 except (AttributeError , ValueError ):
384399 pass # Fall through to the next method
385- # 2. Try legacy `torch.utils.dlpack.to_dlpack`
400+
401+ # 3. Try legacy `torch.utils.dlpack.to_dlpack`
386402 if to_dlpack_legacy :
387403 try :
388404 dlpack = to_dlpack_legacy (tensor )
@@ -392,7 +408,8 @@ def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
392408 f"Warning: Legacy DLPack conversion failed ({ error_legacy } ), "
393409 f"using numpy fallback."
394410 )
395- # 3. If all DLPack methods fail, use numpy fallback
411+
412+ # 4. If all DLPack methods fail, use numpy fallback
396413 numpy_array = tensor .detach ().cpu ().numpy ()
397414 return tvm .runtime .tensor (numpy_array , device = self .device )
398415
@@ -406,28 +423,37 @@ def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
406423 ) from error
407424
408425 def _convert_tvm_to_pytorch (
409- self , tvm_arrays : Union [Any , List [Any ]]
426+ self , tvm_tensors : Union [Any , List [Any ]]
410427 ) -> Union ["torch.Tensor" , List ["torch.Tensor" ]]:
411428 """Convert TVM Tensors to PyTorch tensors using DLPack."""
412- if isinstance (tvm_arrays , (list , tuple )):
413- return [self ._convert_single_tvm_to_pytorch (arr ) for arr in tvm_arrays ]
414- return self ._convert_single_tvm_to_pytorch (tvm_arrays )
429+ if isinstance (tvm_tensors , (list , tuple )):
430+ return [self ._convert_single_tvm_to_pytorch (tensor ) for tensor in tvm_tensors ]
431+ return self ._convert_single_tvm_to_pytorch (tvm_tensors )
415432
416- def _convert_single_tvm_to_pytorch (self , tvm_array : Any ) -> "torch.Tensor" :
417- """Convert a single TVM Tensor to PyTorch tensor using DLPack."""
433+ def _convert_single_tvm_to_pytorch (self , tvm_tensor : Any ) -> "torch.Tensor" :
434+ """Convert a single TVM Tensor to PyTorch tensor using faster DLPack converter ."""
418435 # pylint: disable=import-outside-toplevel
419436 import torch
420437
421- if isinstance (tvm_array , torch .Tensor ):
422- return tvm_array
423- if not isinstance (tvm_array , Tensor ):
424- return torch .tensor (tvm_array )
438+ if isinstance (tvm_tensor , torch .Tensor ):
439+ return tvm_tensor
440+ if not isinstance (tvm_tensor , Tensor ):
441+ return torch .tensor (tvm_tensor )
442+
443+ # 1. Try faster C++ DLPack converter
444+ if _FASTER_DLPACK_EXTENSION is not None :
445+ try :
446+ return torch .from_dlpack (tvm_tensor )
447+ except (AttributeError , ValueError ):
448+ pass # Fall through to the next method
449+
450+ # 2. Try standard DLPack conversion
425451 try :
426- return torch .from_dlpack (tvm_array )
452+ return torch .from_dlpack (tvm_tensor )
427453 # pylint: disable=broad-exception-caught
428454 except Exception as error :
429455 print (f"Warning: DLPack conversion from TVM failed ({ error } ), using numpy fallback" )
430- numpy_array = tvm_array .numpy ()
456+ numpy_array = tvm_tensor .numpy ()
431457 return torch .from_numpy (numpy_array )
432458
433459 def get_function (self , name : str ) -> Optional [PackedFunc ]:
0 commit comments