diff --git a/python/triton/__init__.py b/python/triton/__init__.py
index 15539ecac416..adbceefedd8a 100644
--- a/python/triton/__init__.py
+++ b/python/triton/__init__.py
@@ -4,10 +4,6 @@
 # ---------------------------------------
 # Note: import order is significant here.
 
-# TODO: torch needs to be imported first
-# or pybind11 shows `munmap_chunk(): invalid pointer`
-import torch  # noqa: F401
-
 # submodules
 from .runtime import (
     autotune,
diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py
index 135f71f6cd63..8c1dacbe12bd 100644
--- a/python/triton/compiler/compiler.py
+++ b/python/triton/compiler/compiler.py
@@ -11,8 +11,6 @@
 from pathlib import Path
 from typing import Any, Tuple
 
-import torch
-
 import triton
 import triton._C.libtriton.triton as _triton
 from ..runtime import driver
@@ -324,6 +322,10 @@ def _is_cuda(arch):
 
 
 def get_architecture_descriptor(capability):
+    try:
+        import torch
+    except ImportError:
+        raise ImportError("Triton requires PyTorch to be installed")
     if capability is None:
         if torch.version.hip is None:
             device = triton.runtime.jit.get_current_device()
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 62e2af7b5b63..5047a866b0ad 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -3,8 +3,6 @@
 from functools import wraps
 from typing import List, Optional, Sequence, Tuple, TypeVar
 
-import torch
-
 import triton
 from . import core as tl
 from triton._C.libtriton.triton import ir
@@ -1183,6 +1181,10 @@ def dot(lhs: tl.tensor,
         allow_tf32: bool,
         out_dtype: tl.dtype,
         builder: ir.builder) -> tl.tensor:
+    try:
+        import torch
+    except ImportError:
+        raise ImportError("Triton requires PyTorch to be installed")
     if torch.version.hip is None:
         device = triton.runtime.jit.get_current_device()
         capability = triton.runtime.jit.get_device_capability(device)