Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 42 additions & 16 deletions python/tvm/relax/base_py_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@
except ImportError:
to_dlpack_legacy = None

try:
from tvm_ffi._optional_torch_c_dlpack import load_torch_c_dlpack_extension

_FASTER_DLPACK_EXTENSION = load_torch_c_dlpack_extension()
except ImportError:
_FASTER_DLPACK_EXTENSION = None


class BasePyModule:
"""Base class that allows Python functions in IRModule with DLPack conversion.
Expand Down Expand Up @@ -369,20 +376,29 @@ def _convert_pytorch_to_tvm(
return self._convert_single_pytorch_to_tvm(tensors)

def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
"""Convert a single PyTorch tensor to TVM Tensor with robust fallbacks."""
"""Convert a single PyTorch tensor to TVM Tensor with faster DLPack converter."""
# pylint: disable=import-outside-toplevel
import torch

if isinstance(tensor, Tensor):
return tensor
if isinstance(tensor, torch.Tensor):
# 1. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
# 1. Try faster C++ DLPack converter
if _FASTER_DLPACK_EXTENSION is not None:
try:
dlpack = torch.to_dlpack(tensor)
return tvm.runtime.from_dlpack(dlpack)
except (AttributeError, ValueError):
pass # Fall through to the next method

# 2. Try modern `torch.to_dlpack` (preferred for PyTorch >= 1.7)
try:
dlpack = torch.to_dlpack(tensor)
return tvm.runtime.from_dlpack(dlpack)
except (AttributeError, ValueError):
pass # Fall through to the next method
# 2. Try legacy `torch.utils.dlpack.to_dlpack`

# 3. Try legacy `torch.utils.dlpack.to_dlpack`
if to_dlpack_legacy:
try:
dlpack = to_dlpack_legacy(tensor)
Expand All @@ -392,7 +408,8 @@ def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
f"Warning: Legacy DLPack conversion failed ({error_legacy}), "
f"using numpy fallback."
)
# 3. If all DLPack methods fail, use numpy fallback

# 4. If all DLPack methods fail, use numpy fallback
numpy_array = tensor.detach().cpu().numpy()
return tvm.runtime.tensor(numpy_array, device=self.device)

Expand All @@ -406,28 +423,37 @@ def _convert_single_pytorch_to_tvm(self, tensor: Any) -> Tensor:
) from error

def _convert_tvm_to_pytorch(
self, tvm_arrays: Union[Any, List[Any]]
self, tvm_tensors: Union[Any, List[Any]]
) -> Union["torch.Tensor", List["torch.Tensor"]]:
"""Convert TVM Tensors to PyTorch tensors using DLPack."""
if isinstance(tvm_arrays, (list, tuple)):
return [self._convert_single_tvm_to_pytorch(arr) for arr in tvm_arrays]
return self._convert_single_tvm_to_pytorch(tvm_arrays)
if isinstance(tvm_tensors, (list, tuple)):
return [self._convert_single_tvm_to_pytorch(tensor) for tensor in tvm_tensors]
return self._convert_single_tvm_to_pytorch(tvm_tensors)

def _convert_single_tvm_to_pytorch(self, tvm_array: Any) -> "torch.Tensor":
"""Convert a single TVM Tensor to PyTorch tensor using DLPack."""
def _convert_single_tvm_to_pytorch(self, tvm_tensor: Any) -> "torch.Tensor":
"""Convert a single TVM Tensor to PyTorch tensor using faster DLPack converter."""
# pylint: disable=import-outside-toplevel
import torch

if isinstance(tvm_array, torch.Tensor):
return tvm_array
if not isinstance(tvm_array, Tensor):
return torch.tensor(tvm_array)
if isinstance(tvm_tensor, torch.Tensor):
return tvm_tensor
if not isinstance(tvm_tensor, Tensor):
return torch.tensor(tvm_tensor)

# 1. Try faster C++ DLPack converter
if _FASTER_DLPACK_EXTENSION is not None:
try:
return torch.from_dlpack(tvm_tensor)
except (AttributeError, ValueError):
pass # Fall through to the next method

# 2. Try standard DLPack conversion
try:
return torch.from_dlpack(tvm_array)
return torch.from_dlpack(tvm_tensor)
# pylint: disable=broad-exception-caught
except Exception as error:
print(f"Warning: DLPack conversion from TVM failed ({error}), using numpy fallback")
numpy_array = tvm_array.numpy()
numpy_array = tvm_tensor.numpy()
return torch.from_numpy(numpy_array)

def get_function(self, name: str) -> Optional[PackedFunc]:
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_base_py_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,4 +203,4 @@ def my_softmax(tensor, dim):


if __name__ == "__main__":
pytest.main([__file__])
tvm.testing.main()
52 changes: 4 additions & 48 deletions tests/python/relax/test_base_py_module_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,54 +420,6 @@ def safe_transform(data: T.handle, output: T.handle):
Output[i] = 0.0


if __name__ == "__main__":
# This allows the file to be run directly for debugging
# In normal pytest usage, these classes are automatically tested by TVMScript
print("All test modules defined successfully!")
print("TVMScript will automatically validate these modules during testing.")

# Demo the printer functionality
print("\n" + "=" * 60)
print("DEMO: BasePyModule Printer Functionality")
print("=" * 60)

# Test the printer with SimplePyFuncModule
try:
ir_mod = SimplePyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)

print("\n1. Testing script() method:")
print("-" * 40)
script_output = module.script()
print(script_output[:500] + "..." if len(script_output) > 500 else script_output)

print("\n2. Testing show() method:")
print("-" * 40)
module.show()

print("\n3. Python functions found in pyfuncs:")
print("-" * 40)
if hasattr(ir_mod, "pyfuncs"):
for name, func in ir_mod.pyfuncs.items():
print(f" - {name}: {func}")
else:
print(" No pyfuncs attribute found")

except Exception as e:
print(f"Demo failed: {e}")
print("This is expected for testing-only TVMScript code.")

# Run all tests using tvm.testing.main()
print("\n" + "=" * 60)
print("Running all tests with tvm.testing.main()...")
print("=" * 60)

import tvm.testing

tvm.testing.main()


# Pytest test functions to verify the classes work correctly
def test_simple_pyfunc_module_creation():
"""Test that SimplePyFuncModule can be created."""
Expand Down Expand Up @@ -849,3 +801,7 @@ def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32

# Use numpy for comparison since we have numpy arrays
np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
tvm.testing.main()
Loading