Skip to content

Commit db8dc97

Browse files
[None][fix] Migrate to new cuda binding package name (#6700)
Signed-off-by: Yuan Tong <[email protected]>
1 parent 4ecda91 commit db8dc97

File tree

16 files changed

+80
-20
lines changed

16 files changed

+80
-20
lines changed

cpp/kernels/fmha_v2/fmha_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import subprocess
22

33
import pytest
4-
from cuda import cuda, nvrtc
4+
5+
try:
6+
from cuda.bindings import driver as cuda
7+
from cuda.bindings import nvrtc
8+
except ImportError:
9+
from cuda import cuda, nvrtc
510

611

712
def ASSERT_DRV(err):

tensorrt_llm/_ipc_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,20 @@
1717
import sys
1818
from typing import List, Tuple
1919

20-
from cuda import cuda, cudart
21-
from cuda.cudart import cudaError_t
20+
try:
21+
from cuda.bindings import driver as cuda
22+
from cuda.bindings import runtime as cudart
23+
except ImportError:
24+
from cuda import cuda, cudart
2225

2326
from ._utils import mpi_comm
2427
from .logger import logger
2528
from .mapping import Mapping
2629

2730

28-
def _raise_if_error(error: cudaError_t | cuda.CUresult):
29-
if isinstance(error, cudaError_t):
30-
if error != cudaError_t.cudaSuccess:
31+
def _raise_if_error(error: cudart.cudaError_t | cuda.CUresult):
32+
if isinstance(error, cudart.cudaError_t):
33+
if error != cudart.cudaError_t.cudaSuccess:
3134
raise RuntimeError(f"CUDA Runtime API error: {repr(error)}")
3235
if isinstance(error, cuda.CUresult):
3336
if error != cuda.CUresult.CUDA_SUCCESS:

tensorrt_llm/_mnnvl_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121

2222
import pynvml
2323
import torch
24-
from cuda import cuda
24+
25+
try:
26+
from cuda.bindings import driver as cuda
27+
except ImportError:
28+
from cuda import cuda
2529

2630
from ._dlpack_utils import pack_strided_memory
2731
from ._utils import mpi_comm

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from typing import Dict, Iterable, List, Optional, Tuple, Union
1212

1313
import torch
14-
from cuda import cudart
14+
15+
try:
16+
from cuda.bindings import runtime as cudart
17+
except ImportError:
18+
from cuda import cudart
1519

1620
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
1721
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager

tensorrt_llm/auto_parallel/cluster_info.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
import pynvml
77
import torch
8-
from cuda import cudart
8+
9+
try:
10+
from cuda.bindings import runtime as cudart
11+
except ImportError:
12+
from cuda import cudart
913

1014
from tensorrt_llm._utils import DictConversion
1115
from tensorrt_llm.logger import logger

tensorrt_llm/runtime/generation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
import torch
3030
import tensorrt as trt
3131
# isort: on
32-
from cuda import cudart
32+
try:
33+
from cuda.bindings import runtime as cudart
34+
except ImportError:
35+
from cuda import cudart
3336

3437
from tensorrt_llm.runtime.memory_pools.memory_pools_allocator import \
3538
MemoryPoolsAllocator

tensorrt_llm/runtime/multimodal_model_runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from typing import Optional, Tuple
1414

1515
import torch.nn.functional as F
16-
from cuda import cudart
16+
17+
try:
18+
from cuda.bindings import runtime as cudart
19+
except ImportError:
20+
from cuda import cudart
21+
1722
from huggingface_hub import hf_hub_download
1823
from PIL import Image, UnidentifiedImageError
1924
from safetensors import safe_open

tests/integration/defs/sysinfo/get_sysinfo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424

2525
import psutil
2626
import pynvml
27-
from cuda import cuda
27+
28+
try:
29+
from cuda.bindings import driver as cuda
30+
except ImportError:
31+
from cuda import cuda
2832

2933
# Logger
3034
logger = logging.getLogger(__name__)

tests/microbenchmarks/all_reduce.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
# isort: off
1919
import torch
2020
# isort: on
21-
from cuda import cudart
21+
try:
22+
from cuda.bindings import runtime as cudart
23+
except ImportError:
24+
from cuda import cudart
2225

2326
import tensorrt_llm as tllm
2427
from tensorrt_llm import Mapping

tests/microbenchmarks/build_time_benchmark.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
import traceback
88

99
import tensorrt as trt
10-
from cuda import cudart
10+
11+
try:
12+
from cuda.bindings import runtime as cudart
13+
except ImportError:
14+
from cuda import cudart
1115

1216
import tensorrt_llm
1317
from tensorrt_llm import (AutoConfig, AutoModelForCausalLM, BuildConfig,

0 commit comments

Comments
 (0)