
Commit e21b6a2

[Relax] Update BasePyModule with faster DLPack converter for tensor conversion (#18331)
This PR enhances `BasePyModule` by integrating a faster DLPack converter for efficient tensor conversion between TVM and PyTorch, following #18306.
1 parent: e7bcf17 · commit: e21b6a2
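For readers outside the PR thread: the conversion this commit speeds up is the standard DLPack hand-off between the two frameworks. A minimal, self-contained sketch of that hand-off (assuming recent `torch` and `tvm` builds; the tensor values and shapes are illustrative only):

import torch
import tvm

# PyTorch -> TVM: wrap the tensor in a DLPack capsule and import it zero-copy.
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
tvm_tensor = tvm.runtime.from_dlpack(torch.to_dlpack(x))

# TVM -> PyTorch: torch.from_dlpack consumes any object exposing __dlpack__.
y = torch.from_dlpack(tvm_tensor)
assert torch.equal(x, y)  # same underlying memory, no copy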

File tree: 3 files changed (+47 −65 lines)

  python/tvm/relax/base_py_module.py
  tests/python/relax/test_base_py_module.py
  tests/python/relax/test_base_py_module_printer.py

python/tvm/relax/base_py_module.py

Lines changed: 42 additions & 16 deletions
@@ -32,6 +32,13 @@
 except ImportError:
     to_dlpack_legacy = None

+try:
+    from tvm_ffi._optional_torch_c_dlpack import load_torch_c_dlpack_extension
+
+    _FASTER_DLPACK_EXTENSION = load_torch_c_dlpack_extension()
+except ImportError:
+    _FASTER_DLPACK_EXTENSION = None
+

 class BasePyModule:
     """Base class that allows Python functions in IRModule with DLPack conversion.
@@ -369,20 +376,29 @@ def _convert_pytorch_to_tvm(
         return self._convert_single_pytorch_to_tvm(tensors)

     def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
-        """Convert a single PyTorch tensor to TVM Tensor with robust fallbacks."""
+        """Convert a single PyTorch tensor to TVM Tensor with faster DLPack converter."""
         # pylint: disable=import-outside-toplevel
         import torch

         if isinstance(tensor, Tensor):
             return tensor
         if isinstance(tensor, torch.Tensor):
-            # 1. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
+            # 1. Try faster C++ DLPack converter
+            if _FASTER_DLPACK_EXTENSION is not None:
+                try:
+                    dlpack = torch.to_dlpack(tensor)
+                    return tvm.runtime.from_dlpack(dlpack)
+                except (AttributeError, ValueError):
+                    pass  # Fall through to the next method
+
+            # 2. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
             try:
                 dlpack = torch.to_dlpack(tensor)
                 return tvm.runtime.from_dlpack(dlpack)
             except (AttributeError, ValueError):
                 pass  # Fall through to the next method
-            # 2. Try legacy `torch.utils.dlpack.to_dlpack`
+
+            # 3. Try legacy `torch.utils.dlpack.to_dlpack`
             if to_dlpack_legacy:
                 try:
                     dlpack = to_dlpack_legacy(tensor)
@@ -392,7 +408,8 @@ def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
                         f"Warning: Legacy DLPack conversion failed ({error_legacy}), "
                         f"using numpy fallback."
                     )
-            # 3. If all DLPack methods fail, use numpy fallback
+
+            # 4. If all DLPack methods fail, use numpy fallback
             numpy_array = tensor.detach().cpu().numpy()
             return tvm.runtime.tensor(numpy_array, device=self.device)

@@ -406,28 +423,37 @@ def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
             ) from error

     def _convert_tvm_to_pytorch(
-        self, tvm_arrays: Union[Any, List[Any]]
+        self, tvm_tensors: Union[Any, List[Any]]
     ) -> Union["torch.Tensor", List["torch.Tensor"]]:
         """Convert TVM Tensors to PyTorch tensors using DLPack."""
-        if isinstance(tvm_arrays, (list, tuple)):
-            return [self._convert_single_tvm_to_pytorch(arr) for arr in tvm_arrays]
-        return self._convert_single_tvm_to_pytorch(tvm_arrays)
+        if isinstance(tvm_tensors, (list, tuple)):
+            return [self._convert_single_tvm_to_pytorch(tensor) for tensor in tvm_tensors]
+        return self._convert_single_tvm_to_pytorch(tvm_tensors)

-    def _convert_single_tvm_to_pytorch(self, tvm_array: Any) -> "torch.Tensor":
-        """Convert a single TVM Tensor to PyTorch tensor using DLPack."""
+    def _convert_single_tvm_to_pytorch(self, tvm_tensor: Any) -> "torch.Tensor":
+        """Convert a single TVM Tensor to PyTorch tensor using faster DLPack converter."""
         # pylint: disable=import-outside-toplevel
         import torch

-        if isinstance(tvm_array, torch.Tensor):
-            return tvm_array
-        if not isinstance(tvm_array, Tensor):
-            return torch.tensor(tvm_array)
+        if isinstance(tvm_tensor, torch.Tensor):
+            return tvm_tensor
+        if not isinstance(tvm_tensor, Tensor):
+            return torch.tensor(tvm_tensor)
+
+        # 1. Try faster C++ DLPack converter
+        if _FASTER_DLPACK_EXTENSION is not None:
+            try:
+                return torch.from_dlpack(tvm_tensor)
+            except (AttributeError, ValueError):
+                pass  # Fall through to the next method
+
+        # 2. Try standard DLPack conversion
         try:
-            return torch.from_dlpack(tvm_array)
+            return torch.from_dlpack(tvm_tensor)
         # pylint: disable=broad-exception-caught
         except Exception as error:
             print(f"Warning: DLPack conversion from TVM failed ({error}), using numpy fallback")
-            numpy_array = tvm_array.numpy()
+            numpy_array = tvm_tensor.numpy()
             return torch.from_numpy(numpy_array)

     def get_function(self, name: str) -> Optional[PackedFunc]:
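Condensed view of the fallback chain that the diff above installs in `_convert_single_pytorch_to_tvm`: try the DLPack fast path first, and only fall back to a host copy through NumPy if that fails. This is a sketch under those assumptions, not the class method itself; it omits the optional C++ extension check and the `self.device` argument.

import torch
import tvm

def torch_to_tvm(tensor: torch.Tensor):
    try:
        # Fast path: zero-copy DLPack hand-off.
        return tvm.runtime.from_dlpack(torch.to_dlpack(tensor))
    except (AttributeError, ValueError):
        # Last resort: copy through host memory via NumPy.
        return tvm.runtime.tensor(tensor.detach().cpu().numpy())

print(torch_to_tvm(torch.ones(4)))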

tests/python/relax/test_base_py_module.py

Lines changed: 1 addition & 1 deletion
@@ -203,4 +203,4 @@ def my_softmax(tensor, dim):


 if __name__ == "__main__":
-    pytest.main([__file__])
+    tvm.testing.main()
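The entrypoint change above follows the repository convention of driving test files through `tvm.testing.main()` rather than invoking pytest directly. A hypothetical minimal test file using that entrypoint (the test body is illustrative and not taken from this commit):

import tvm
import tvm.testing

def test_cpu_device_type():
    # kDLCPU is device type 1 in the DLPack spec.
    assert tvm.cpu().device_type == 1

if __name__ == "__main__":
    tvm.testing.main()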

tests/python/relax/test_base_py_module_printer.py

Lines changed: 4 additions & 48 deletions
@@ -420,54 +420,6 @@ def safe_transform(data: T.handle, output: T.handle):
             Output[i] = 0.0


-if __name__ == "__main__":
-    # This allows the file to be run directly for debugging
-    # In normal pytest usage, these classes are automatically tested by TVMScript
-    print("All test modules defined successfully!")
-    print("TVMScript will automatically validate these modules during testing.")
-
-    # Demo the printer functionality
-    print("\n" + "=" * 60)
-    print("DEMO: BasePyModule Printer Functionality")
-    print("=" * 60)
-
-    # Test the printer with SimplePyFuncModule
-    try:
-        ir_mod = SimplePyFuncModule
-        device = tvm.cpu()
-        module = BasePyModule(ir_mod, device)
-
-        print("\n1. Testing script() method:")
-        print("-" * 40)
-        script_output = module.script()
-        print(script_output[:500] + "..." if len(script_output) > 500 else script_output)
-
-        print("\n2. Testing show() method:")
-        print("-" * 40)
-        module.show()
-
-        print("\n3. Python functions found in pyfuncs:")
-        print("-" * 40)
-        if hasattr(ir_mod, "pyfuncs"):
-            for name, func in ir_mod.pyfuncs.items():
-                print(f"  - {name}: {func}")
-        else:
-            print("  No pyfuncs attribute found")
-
-    except Exception as e:
-        print(f"Demo failed: {e}")
-        print("This is expected for testing-only TVMScript code.")
-
-    # Run all tests using tvm.testing.main()
-    print("\n" + "=" * 60)
-    print("Running all tests with tvm.testing.main()...")
-    print("=" * 60)
-
-    import tvm.testing
-
-    tvm.testing.main()
-
-
 # Pytest test functions to verify the classes work correctly
 def test_simple_pyfunc_module_creation():
     """Test that SimplePyFuncModule can be created."""
@@ -849,3 +801,7 @@ def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32

     # Use numpy for comparison since we have numpy arrays
     np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, atol=1e-5)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
